Skip to content

Conversation

@kocchop
Copy link
Collaborator

@kocchop kocchop commented Dec 31, 2025

Description

Enables sequence packing for context parallelism with ring strategy using TransformerEngine's DotProductAttention. Includes comprehensive GPU tests for ring attention with packing for sm90+.

  • Currently supports packing only for ring attention
  • Replaced local sequence reordering with TE reorder_causal_load_balancing api
  • Currently the load balancing strategy is automatically picked based on the packing config

Tests

Added a GPU integration test that works for sm90+.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link

codecov bot commented Dec 31, 2025

Copy link
Collaborator

@richjames0 richjames0 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a couple of nits but lgtm

# Handle packing configurations
if self.config.packing and self.config.dataset_type != "synthetic":
if using_context_parallelism and not using_load_balanced_ring_cp:
raise AssertionError("Packing is only supported for load balanced ring attention with context parallelism.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: AssertionError feels weird here to me. Maybe an argumenterror?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

converted to ValueError



def get_reorder_callable(cp_size, shard_mode):
def get_reorder_callable(cp_size, shard_mode, reorder_strategy=0): # 0=DualChunkSwap, 1=Striped
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I read this late at night I imagine you're using an integer here so it's comprehensible by JAX but could this be made into an enum without breaking things (at worse using .value?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed to enum @richjames0



def shard_reorder_causal_load_balanced(batch, cp_size, shard_mode):
def shard_reorder_causal_load_balanced(batch, cp_size, shard_mode, reorder_strategy=0):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make reorder_strategy configurable via base.yml?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, ptal @gobbleturk

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment - Terminology/naming is hard, folks have been using the name "striped" to refer to DUAL_CHUNK_SWAP. I guess the string "striped" has to be passed to transformer engine for the other strategy (I would prefer the name "interleaved")...

I highly appreciate your comment with examples of the two strategies to clearly show what they mean in our codebase anyway!

…llelism

- Add ReorderStrategy enum to common_types.py (AUTO, DUAL_CHUNK_SWAP, STRIPED)
- Add context_parallel_reorder_strategy config option
- Update pyconfig, types.py, and train_utils.py to use enum
- Map MaxText enum to TE ReorderStrategy in max_utils.py
"""Reorders the example batch sequences"""
@partial(jax.jit, static_argnums=(1, 2))
def reorder_causal_load_balanced(batch, cp_size, reorder_strategy):
"""Reorders the example batch sequences
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great comment explaining the two of them with examples, thank you!

Copy link
Collaborator

@gobbleturk gobbleturk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the tests and great comments illustrating the two reorder strategies!

Copy link
Collaborator

@RissyRan RissyRan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just minor comments.

### Determine if we want to use load balance for context parallelism
context_parallel_load_balance: True
context_parallel_strategy: "all_gather" # "all_gather" or "ring"
context_parallel_reorder_strategy: "auto" # "auto", "dual_chunk_swap", or "striped"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this strategy! Could you help add some explanation here for each from reorder_causal_load_balanced?

data_iterator = map(maxtext_utils.get_reorder_callable(context_parallel_size, config.shard_mode), data_iterator)

# Determine load balancing reorder strategy based on whether packing is enabled
if config.context_parallel_reorder_strategy == ReorderStrategy.AUTO:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering if AUTO usually gives the best result? Or if there will be an compatible issue if user select ReorderStrategy.STRIPED, but without packing?

Trying to understand if we just provides 2 strategies are good enough.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants