Add Packing Support for Context Parallelism (Ring Attention) #2906
base: main
The first file in the diff extends `reorder_causal_load_balanced` to take a `reorder_strategy` argument and delegate to TransformerEngine's `reorder_causal_load_balancing`:

```diff
@@ -791,13 +791,50 @@ def reorder_sequence(tensor, cp_size: int, seq_dim: int = 1, to_contiguous: bool
   return reordered.reshape(ori_tensor_shape)


-@partial(jax.jit, static_argnums=1)
-def reorder_causal_load_balanced(batch, cp_size):
-  """Reorders the example batch sequences"""
+@partial(jax.jit, static_argnums=(1, 2))
+def reorder_causal_load_balanced(batch, cp_size, reorder_strategy):
+  """Reorders the example batch sequences.
+
+  Args:
+    batch: The batch to reorder.
+    cp_size: The size of the context parallelism.
+    reorder_strategy: The ReorderStrategy enum value (DUAL_CHUNK_SWAP or STRIPED).
+
+  Returns:
+    The reordered batch.
+
+  Reorder strategies:
+  - DUAL_CHUNK_SWAP: Splits each GPU's query into two chunks and mirror-swaps the second
+    chunks between GPUs. Currently used for non-THD load balancing; requires max_seqlen to be
+    a multiple of 2 * cp_size.
+    Example (4 GPUs, seqlen=16):
+    - Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; GPU2: [8, 9, 10, 11]; GPU3: [12, 13, 14, 15]
+    - After reorder:  GPU0: [0, 1, 14, 15]; GPU1: [4, 5, 10, 11]; GPU2: [8, 9, 6, 7]; GPU3: [12, 13, 2, 3]
+
+  - STRIPED: Distributes tokens across GPUs in a striped (interleaved) manner along the
+    sequence. Currently used for THD (packed) load balancing.
+    Example (4 GPUs, seqlen=16):
+    - Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; ...; GPU3: [12, 13, 14, 15]
+    - After reorder:  GPU0: [0, 4, 8, 12]; GPU1: [1, 5, 9, 13]; ...; GPU3: [3, 7, 11, 15]
+
+  See: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/attention.py
+  """
+  # pylint: disable=import-outside-toplevel
+  from transformer_engine.jax.attention import ReorderStrategy as TE_ReorderStrategy
+  from transformer_engine.jax.attention import reorder_causal_load_balancing
+  from MaxText.common_types import ReorderStrategy
+
+  reorder_strategy_map = {
+      ReorderStrategy.DUAL_CHUNK_SWAP: TE_ReorderStrategy.DualChunkSwap,
+      ReorderStrategy.STRIPED: TE_ReorderStrategy.Striped,
+  }
+
   return {
-      key: reorder_sequence(
+      key: reorder_causal_load_balancing(
           value,  # Pass each key's value inside batch separately
+          reorder_strategy_map[reorder_strategy],
          cp_size=cp_size,
          seq_dim=1,
      )
      if key
      in ["inputs", "targets", "inputs_position", "targets_position", "inputs_segmentation", "targets_segmentation"]
```

**Collaborator** (on the new docstring): Great comment explaining the two of them with examples, thank you!
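To make the two permutations concrete, here is a minimal self-contained sketch in plain NumPy that reproduces the per-GPU shards from the docstring examples. It is illustrative only, not the TransformerEngine or MaxText implementation, and the helper names `dual_chunk_swap` and `striped` are hypothetical:

```python
import numpy as np

def dual_chunk_swap(tokens: np.ndarray, cp_size: int) -> np.ndarray:
  """Illustrative DUAL_CHUNK_SWAP: each GPU keeps its first chunk; second chunks are mirror-swapped."""
  assert tokens.size % (2 * cp_size) == 0, "seqlen must be a multiple of 2 * cp_size"
  chunks = tokens.reshape(cp_size, 2, -1)            # (gpu, half, chunk_len)
  first, second = chunks[:, 0], chunks[:, 1]
  return np.stack([first, second[::-1]], axis=1).reshape(-1)  # reverse pairs second halves across GPUs

def striped(tokens: np.ndarray, cp_size: int) -> np.ndarray:
  """Illustrative STRIPED: GPU g receives tokens g, g + cp_size, g + 2*cp_size, ..."""
  return tokens.reshape(-1, cp_size).T.reshape(-1)

tokens = np.arange(16)
for name, fn in [("DUAL_CHUNK_SWAP", dual_chunk_swap), ("STRIPED", striped)]:
  shards = np.split(fn(tokens, cp_size=4), 4)        # contiguous shard per GPU after the reorder
  print(name, [s.tolist() for s in shards])
# DUAL_CHUNK_SWAP -> [[0, 1, 14, 15], [4, 5, 10, 11], [8, 9, 6, 7], [12, 13, 2, 3]]
# STRIPED         -> [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]
```

Both reorderings serve the same purpose: with a plain contiguous split, GPUs holding later tokens do far more causal-attention work than GPUs holding earlier tokens, so tokens are permuted before sharding to even out the load across the ring.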
The second file wires the strategy into `setup_train_loop`: it tightens the packing validation (packing with context parallelism now requires the 'ring' strategy and a non-synthetic dataset), resolves the AUTO reorder strategy, passes it to `get_reorder_callable`, and creates the data loader after the reordering wrapper is applied:

```diff
@@ -30,6 +30,7 @@
 from MaxText.utils.goodput_utils import GoodputEvent
 from MaxText.utils.goodput_utils import maybe_record_goodput
 from MaxText import model_creation_utils
+from MaxText.common_types import ReorderStrategy


 def create_training_tools(config, model, mesh):
@@ -186,26 +187,43 @@ def setup_train_loop(config, recorder, devices=None):
   with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION):
     data_iterator, eval_data_iterator = create_data_iterator(config, mesh)
     rampup_manager = create_rampup_manager(config, checkpoint_manager)
-    data_loader = create_dataloader(config, mesh, data_iterator, recorder, rampup_manager)
     context_parallel_size = mesh.shape["context"]
-    # Check if context parallelism is being used with sequence packing
-    if context_parallel_size > 1 and config.packing and config.dataset_type != "synthetic":
-      raise ValueError(
-          "Context parallelism cannot be used with sequence packing. "
-          "Disable sequence packing (set packing=False). "
-          "Context parallelism with packing support will be added soon."
-      )
+    # Validate context parallelism with packing configuration
+    if context_parallel_size > 1 and config.packing:
+      if config.dataset_type == "synthetic":
+        raise ValueError(
+            "Context parallelism with sequence packing is not supported with synthetic data. "
+            "Please disable sequence packing (set packing=False)."
+        )
+      if config.context_parallel_strategy != "ring":
+        raise ValueError(
+            "Context parallelism with 'all_gather' strategy cannot be used with sequence packing. "
+            "Please use 'ring' strategy instead."
+        )

     # Apply reordering wrapper to data iterators if context parallelism is enabled
     with jax.set_mesh(mesh):
       if context_parallel_size > 1 and config.context_parallel_load_balance:
-        data_iterator = map(maxtext_utils.get_reorder_callable(context_parallel_size, config.shard_mode), data_iterator)
+        # Determine load balancing reorder strategy based on whether packing is enabled
+        if config.context_parallel_reorder_strategy == ReorderStrategy.AUTO:
+          reorder_strategy = ReorderStrategy.STRIPED if config.packing else ReorderStrategy.DUAL_CHUNK_SWAP
+        else:
+          reorder_strategy = config.context_parallel_reorder_strategy
+
+        data_iterator = map(
+            maxtext_utils.get_reorder_callable(context_parallel_size, config.shard_mode, reorder_strategy),
+            data_iterator,
+        )
         if eval_data_iterator:
           eval_data_iterator = map(
-              maxtext_utils.get_reorder_callable(context_parallel_size, config.shard_mode),
+              maxtext_utils.get_reorder_callable(context_parallel_size, config.shard_mode, reorder_strategy),
               eval_data_iterator,
           )

+    # Create data_loader AFTER reordering wrapper is applied
+    data_loader = create_dataloader(config, mesh, data_iterator, recorder, rampup_manager)

     state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(
         model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager
     )
```

**Collaborator** (on the `ReorderStrategy.AUTO` branch): Wondering if AUTO usually gives the best result? Or whether there will be a compatibility issue if a user selects ReorderStrategy.STRIPED but without packing? Trying to understand whether providing just these two strategies is good enough.
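The AUTO branch above can be read as a small pure function. The sketch below restates that logic outside MaxText; the `ReorderStrategy` enum is re-declared locally only so the snippet runs standalone (the real definition lives in `MaxText.common_types`), and `resolve_reorder_strategy` is a hypothetical helper name:

```python
from enum import Enum, auto

class ReorderStrategy(Enum):
  # Assumed members, mirroring MaxText.common_types.ReorderStrategy as used in the diff.
  AUTO = auto()
  DUAL_CHUNK_SWAP = auto()
  STRIPED = auto()

def resolve_reorder_strategy(configured: ReorderStrategy, packing: bool) -> ReorderStrategy:
  """AUTO picks STRIPED for packed (THD) batches and DUAL_CHUNK_SWAP otherwise; explicit choices pass through."""
  if configured == ReorderStrategy.AUTO:
    return ReorderStrategy.STRIPED if packing else ReorderStrategy.DUAL_CHUNK_SWAP
  return configured

assert resolve_reorder_strategy(ReorderStrategy.AUTO, packing=True) == ReorderStrategy.STRIPED
assert resolve_reorder_strategy(ReorderStrategy.AUTO, packing=False) == ReorderStrategy.DUAL_CHUNK_SWAP
assert resolve_reorder_strategy(ReorderStrategy.STRIPED, packing=False) == ReorderStrategy.STRIPED
```

An explicitly configured strategy always wins; AUTO simply maps packed batches to STRIPED and unpacked batches to DUAL_CHUNK_SWAP, matching the docstring's note about which layout each strategy currently targets.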
**Review comment:** Thanks for adding this strategy! Could you add a short explanation here for each strategy, as in `reorder_causal_load_balanced`?