[CP] Refactor Context Parallel to use new PyTorch CP APIs #2144
base: gh/fegin/53/base
Conversation
logger.info("Applied DDP to the model")

def apply_cp(
- should we put this function in distributed/context_parallel.py?
- should we apply this to all models?
1. Yes.
2. Do we actually verify CP for all models? I think llama3 and llama4, yes. I'm thinking of re-enabling CP model by model, using this refactor as the chance.
nvm, I added it to the core models. I'll leave Flux for another PR. cc @wwwjn
| "attention_norm": SequenceParallel(), | ||
| # NOTE: when the fourth argument (positions) is not None, its input layout | ||
| # and desired input layout should be Replicate() | ||
| # and desired input layout is still None as we don't convert freqs_cis to |
maybe we should change this after @wwwjn's PR?
PR #2149 has a composability issue when TP + PP is applied, and I'm trying to discuss how to fix it. I guess we could also land this PR first if it's ready, and I could rebase.
I don't have a strong opinion. Let's land whichever PR is ready first.
input_dict, labels
)
# apply context parallelism if cp is enabled
# ensure CP handles the separate freqs_cis buffer for each pp stage
After this PR, we no longer seem to need freqs_cis as a model input. IIRC we previously modified the freqs_cis-related model code to make it a model input. Shall we revert that logic?
That's a good point.
We now use max_seq_len and reassign max_seq_len during model_args initialization. I don't think this counts as a hack, so I only removed the legacy TODO (which I don't think we need anymore).
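As a rough illustration of that reassignment, here is a minimal sketch; the class, method, and config attribute names are assumptions about the shape of the code, not quotes from the repo.

```python
# Illustrative only: ModelArgs, update_from_config, and job_config.training.seq_len
# are assumed names, not excerpts from torchtitan.
from dataclasses import dataclass
from typing import Any


@dataclass
class ModelArgs:
    max_seq_len: int = 2048

    def update_from_config(self, job_config: Any) -> None:
        # Reassign max_seq_len from the training config during initialization,
        # so buffers such as freqs_cis can be sized without passing the
        # sequence length in as a model input.
        self.max_seq_len = job_config.training.seq_len
```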
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):

* #2145
* #2144
* __->__ #2143

1. Accept one "." (meaning the current commit) case to simplify the command line.
2. Ignore the untracked files.
**Summary**

1. Refactored CP Dispatching:
   - New apply_cp() function uses PyTorch's _ContextParallel parallelization plan to dispatch the attention call.
   - Enables the CP dispatcher for the SDPA attention type inside apply_cp().
2. New CP Data Sharding Approach:
   - Added a cp_shard() helper function that wraps PyTorch's _context_parallel_shard API.
   - Uses _HeadTailLoadBalancer for SDPA attention load balancing.
   - FlexAttention CP support is deferred to a future PR.
   - CP sharding now happens explicitly in post_dataloading_process(), where inputs, labels, and positions are sharded.
   - The new positions argument allows us to not shard freqs_cis.

Note that this PR requires pytorch/pytorch#170200.

**Test**

```
-> % python3 scripts/loss_compare.py . chienchin/loss_compare --baseline-options="--parallelism.context_parallel_degree=8" --test-options="--parallelism.context_parallel_degree=8" --steps=100 --assert-equal
pick 5903566a Improve the loss_compare.sh logic
[LOSS_COMPARE]
[LOSS_COMPARE] Asserting losses are equal...
[LOSS_COMPARE] Baseline log: /tmp/baseline_training.log
[LOSS_COMPARE] Test log: /tmp/test_training.log
[LOSS_COMPARE] Extracted 100 steps from baseline log
[LOSS_COMPARE] Extracted 100 steps from test log
test_losses_equal (__main__.assert_losses_equal.<locals>.LossEqualityTest.test_losses_equal) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.000s

OK
[LOSS_COMPARE] All losses are equal. Assertion passed!
[LOSS_COMPARE] ==========================================
[LOSS_COMPARE] LOSS COMPARISON ANALYSIS
[LOSS_COMPARE] ==========================================
[LOSS_COMPARE] Step-by-step loss comparison:
[LOSS_COMPARE] Step   Baseline Loss   Test Loss   Difference
[LOSS_COMPARE] ----   -------------   ---------   ----------
[LOSS_COMPARE] 1      8.1309          8.1309      0.000000
[LOSS_COMPARE] 2      7.8268          7.8268      0.000000
[LOSS_COMPARE] 3      7.2284          7.2284      0.000000
[LOSS_COMPARE] 4      6.4669          6.4669      0.000000
[LOSS_COMPARE] 5      5.4017          5.4017      0.000000
[LOSS_COMPARE] 6      4.7656          4.7656      0.000000
[LOSS_COMPARE] 7      4.3587          4.3587      0.000000
[LOSS_COMPARE] 8      4.0938          4.0938      0.000000
[LOSS_COMPARE] 9      4.4019          4.4019      0.000000
[LOSS_COMPARE] 10     3.7451          3.7451      0.000000
....
[LOSS_COMPARE] 90     2.802           2.802       0.000000
[LOSS_COMPARE] 91     2.7207          2.7207      0.000000
[LOSS_COMPARE] 92     2.7454          2.7454      0.000000
[LOSS_COMPARE] 93     2.6992          2.6992      0.000000
[LOSS_COMPARE] 94     2.743           2.743       0.000000
[LOSS_COMPARE] 95     2.7534          2.7534      0.000000
[LOSS_COMPARE] 96     2.8403          2.8403      0.000000
[LOSS_COMPARE] 97     2.783           2.783       0.000000
[LOSS_COMPARE] 98     3.0892          3.0892      0.000000
[LOSS_COMPARE] 99     2.7905          2.7905      0.000000
[LOSS_COMPARE] 100    2.733           2.733       0.000000
[LOSS_COMPARE]
[LOSS_COMPARE] Summary statistics:
[LOSS_COMPARE]   Average baseline loss: 3.1414940000000002
[LOSS_COMPARE]   Average test loss: 3.1414940000000002
[LOSS_COMPARE]   Average difference: 0.000000
[LOSS_COMPARE]
[LOSS_COMPARE] Loss comparison complete. No results saved (no output folder specified).
```

**TODO**

- This PR will invalidate torch.compile + CP due to pytorch/pytorch#170110. We will have to wait for Dynamo to fix the issue or refactor the nn.Module core logic to avoid the hook_id check.

[ghstack-poisoned]
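For intuition on the _HeadTailLoadBalancer mentioned in the summary, the following self-contained sketch shows the head-tail chunk assignment idea for causal SDPA load balancing. It is only an illustration of the concept, not the private PyTorch API that cp_shard() actually wraps, and the function name is made up for this example.

```python
# Conceptual sketch of head-tail load balancing for causal attention.
# Not the PyTorch _HeadTailLoadBalancer API; names here are illustrative.
import torch


def head_tail_shard(
    x: torch.Tensor, cp_rank: int, cp_degree: int, seq_dim: int = 1
) -> torch.Tensor:
    # Split the sequence into 2 * cp_degree chunks; rank r keeps chunk r from
    # the head and the mirrored chunk from the tail.
    chunks = x.chunk(2 * cp_degree, dim=seq_dim)
    return torch.cat(
        [chunks[cp_rank], chunks[2 * cp_degree - 1 - cp_rank]], dim=seq_dim
    )


if __name__ == "__main__":
    positions = torch.arange(16).unsqueeze(0)  # 1 batch, seq_len 16
    for rank in range(4):  # cp_degree = 4 -> each rank keeps 4 positions
        print(rank, head_tail_shard(positions, rank, 4).tolist())
    # rank 0 -> [[0, 1, 14, 15]], rank 1 -> [[2, 3, 12, 13]], ...
```

Pairing an early (cheap) chunk with a late (expensive) chunk keeps the causal-attention cost roughly even across CP ranks.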
self,
input_dict: dict[str, torch.Tensor],
labels: torch.Tensor,
device: torch.device,
can we just use input_dict["inputs"].device?
def post_dataloading_process(
    self,
    input_dict: dict[str, torch.Tensor],
    labels: torch.Tensor,
I feel we should consolidate this into input_dict["labels"]
A tuple of (inputs, labels, extra_inputs, extra_kwargs) where:
    - inputs: Main input tensor extracted from input_dict["input"].
    - labels: Target labels (potentially modified by CP sharding).
    - extra_inputs: Dict of auxiliary input tensors (all keys except
maybe we can consolidate this into input_dict as well?
# extra_kwargs are.
extra_kwargs: dict[str, Any] = {}

attn_type = getattr(self.model_args, "attn_type", "sdpa")
It seems we introduce model_args to the validator because of this line.
This condition is checked again in get_attention_masks. Can we remove it here? We can return None instead of throwing in get_attention_masks when attn_type is sdpa.
The argument is that the validator is not supposed to know ModelArgs details.
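A minimal sketch of that suggestion, assuming get_attention_masks can simply return None for SDPA; the signature and the SimpleNamespace stand-in are illustrative, not the repo's actual code.

```python
# Hypothetical sketch of the suggestion above; only attn_type and
# get_attention_masks come from the surrounding diff, the rest is illustrative.
from types import SimpleNamespace
from typing import Any, Optional


def get_attention_masks(model_args: Any, input_batch: Any) -> Optional[Any]:
    attn_type = getattr(model_args, "attn_type", "sdpa")
    if attn_type == "sdpa":
        # SDPA needs no explicit mask object, so return None instead of raising.
        return None
    # flex / varlen mask construction would go here.
    raise NotImplementedError(f"mask construction for {attn_type!r} not shown")


# The validator then stays agnostic of ModelArgs:
masks = get_attention_masks(SimpleNamespace(attn_type="sdpa"), input_batch=None)
assert masks is None  # nothing to pass for SDPA
```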
return total_norm

def cp_shard(
can this function also go to context_parallel.py? maybe we can consolidate with get_context_parallel_inputs since both are not long.
from torchtitan.tools.logging import logger

def apply_cp(
This name sounds too generic for what it assumes, e.g. you needed a different one for Flux.
Maybe apply_cp_to_transformer_blocks, and pass in model.layers.values()?
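A sketch of how the proposed rename could look. The plan attachment is left as a placeholder because it goes through PyTorch's private _ContextParallel API; the helper name _attach_cp_plan and the exact signature are assumptions, not code from the PR.

```python
# Sketch of the rename proposed above: operate on an iterable of transformer
# blocks so each model decides what to pass in (e.g. model.layers.values() for
# llama, a different set of modules for Flux).
from typing import Iterable

import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh


def _attach_cp_plan(attention: nn.Module, cp_mesh: DeviceMesh) -> None:
    # Placeholder: the real code attaches PyTorch's private _ContextParallel
    # plan here; its signature is intentionally not reproduced.
    ...


def apply_cp_to_transformer_blocks(
    blocks: Iterable[nn.Module], cp_mesh: DeviceMesh
) -> None:
    for block in blocks:
        # Each block exposes its attention entry point at
        # attention.inner_attention (per the diff above).
        _attach_cp_plan(block.attention.inner_attention, cp_mesh)


# Example call site, per the suggestion:
#     apply_cp_to_transformer_blocks(model.layers.values(), world_mesh["cp"])
```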
Args:
    model: The transformer model with layers containing attention modules
    cp_mesh: Device mesh for context parallel dimension
    use_flex_attn: Whether the model uses FlexAttention (True) or SDPA (False)
n00b q: We don't consider varlen + CP here? why is that?
- Applies to transformer_block.attention.inner_attention for each layer
"""
# Apply context parallelism to every transformer block
# TODO: make seq_sim configurable once the implementation doesn't assume 2
Suggested change:
- # TODO: make seq_sim configurable once the implementation doesn't assume 2
+ # TODO: make seq_dim configurable once the implementation doesn't assume 2
else:
    # This is currently required as DTensor dispatcher is not enabled to
    # dispatch SDPA to CP implementation. We don't disable the CP
    # dispatching in TorchTitan as it is not needed. But there is a
This comment is a little confusing to me: the DTensor dispatcher is not enabled to dispatch SDPA to the CP implementation, so we explicitly enable it by calling _enable_context_parallel_dispatcher. But then why does it say "we don't disable the CP dispatching in TorchTitan"?
)

if parallel_dims.cp_enabled:
    use_flex_attn = attn_type == "flex"
So for now, if attn_type == "varlen", it will fall into the SDPA branch in apply_cp?