Conversation

Force-pushed 48b2a11 to 07c0ff4.

Need to rebase onto #1776
tianyu-l
left a comment
Looks great in general. Left some comments. May need some rebase on recent & near-future development.
Summary of current status: There are some prerequisite PRs:
Once these PRs are landed, I will refactor:
…ks but reduces mfu for 20b
fegin
left a comment
Please address all the comments before landing. I would appreciate it if you could add the reason why we cannot use AuxOutput to the code. Thanks!
    n_kv_heads: int = 8
    sliding_window_size: int = 128
    attn_mask_type: str = "causal"
    use_flex_attn: bool = True
I explicitly leave the parameter here to be compatible with https://github.com/pytorch/torchtitan/blob/refs/heads/main/torchtitan/train.py#L428, where we need to call get_attention_masks.
But I added a note here to prevent users from changing this flag to false.
tianyu-l
left a comment
I think I found a tricky numerical bug in TP. Maybe we can disable it for now.
    - Up to `window_size - 1` previous tokens

    Args:
        window_size: The maximum number of tokens to attend to (including current token).
            Must be >= 1. A window_size of 1 means attend only to self.
We need to raise a ValueError if the user didn't set window_size >= 1.
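A minimal sketch of the requested validation. The function name and where the check lives are assumptions; only the `window_size >= 1` contract comes from the docstring quoted above.

```python
def validate_window_size(window_size: int) -> int:
    """Hypothetical helper: reject window sizes the docstring disallows."""
    if window_size < 1:
        # window_size == 1 means attend only to self; anything smaller is invalid.
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    return window_size
```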
    mlp1_weight = self.mlp1_weight.to_local()
    mlp1_bias = self.mlp1_bias.to_local()
    mlp2_weight = self.mlp2_weight.to_local()
    mlp2_bias = self.mlp2_bias.to_local()
This might not be correct.
When they are DTensors, x * mlp2_weight + mlp2_bias will have placements Partial + Replicate, and sharding prop can automatically convert Replicate -> Partial first and then perform the addition.
However, when we do to_local, the DTensor placement info is discarded, so instead of adding mlp2_bias once, the net effect will be adding tp_degree * mlp2_bias.
I don't have a clean way to solve this. For forward correctness, we can do mlp2_bias / tp_degree to cancel the extra reduction effect, but the backward will have an extra * tp_degree. Can we wrap mlp2_bias / tp_degree in torch.no_grad so the backward doesn't perform * tp_degree?
You can also disable TP / ETP altogether for gpt-oss for now and leave a TODO.
cc @ezyang @fmassa on difficulties of making TP correct in a local tensor region, when there is bias involved.
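The double-counting described in this comment can be reproduced in a single process with plain tensors. A toy sketch (tp_degree, the shard values, and the bias are all illustrative numbers, not the actual model):

```python
import torch

# With DTensor, a Partial activation plus a Replicate bias is handled by
# sharding prop: the bias is effectively added exactly once after reduction.
# With to_local, each of the tp_degree ranks adds the full bias before the
# all-reduce, so the reduced result contains tp_degree * bias.
tp_degree = 2
partial_shards = [torch.tensor([1.0]), torch.tensor([2.0])]  # shards sum to 3.0
bias = torch.tensor([10.0])

# DTensor semantics: reduce the Partial first, then add the bias once.
dtensor_result = sum(partial_shards) + bias

# to_local semantics: every rank adds bias, then the shard sum is all-reduced.
local_result = sum(shard + bias for shard in partial_shards)
```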
    self.mlp1_weight = nn.Parameter(torch.empty((num_experts, dim, hidden_dim * 2)))
    self.mlp1_bias = nn.Parameter(torch.empty((num_experts, hidden_dim * 2)))
    self.mlp2_weight = nn.Parameter(torch.empty((num_experts, hidden_dim, dim)))
    self.mlp2_bias = nn.Parameter(torch.empty((num_experts, dim)))
This is different from the main moe.py, where we init the weight params to have shape (num_experts, out_dim, in_dim) and transpose them before use. The point is hardware efficiency (mainly in the low-precision case). We also need to change the TP / ETP plans to adapt.
See #1517
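A toy sketch of the (num_experts, out_dim, in_dim) storage convention described above, with the transpose applied right before the matmul. All shapes and names here are illustrative, not the actual moe.py code:

```python
import torch

# Illustrative sizes only.
num_experts, in_dim, out_dim, tokens = 2, 4, 6, 3

# Stored transposed as (num_experts, out_dim, in_dim) for hardware efficiency.
w = torch.randn(num_experts, out_dim, in_dim)
x = torch.randn(num_experts, tokens, in_dim)

# Transpose back to (in_dim, out_dim) at use time for the batched matmul.
y = torch.bmm(x, w.transpose(-2, -1))
```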
    # 2. `self._compiled_flex_attn` is not correct, `self` will be passed in
    #    as the first argument, which will cause an error.
    #    `FlexAttentionWrapper._compiled_flex_attn` is correct.
    # 3. Used `return_lse` instead of `return_aux` because of easier TP module notation
Yeah, can you explain this?
This API will be removed in a future release
Thanks! I also noticed that return_lse is being deprecated. The reason we use it here is that we want to use a TP annotation to change the lse tensor back to a DTensor with placement Shard(1) (in the TP region, it's a plain tensor): https://github.com/pytorch/torchtitan/pull/1754/files#diff-3448dcaf6e8b68f3b66a8e1dd298273de3702f93de406569426cd9e03fd7f97bR222. We cannot annotate an AuxOutput() object directly using the TP APIs, and because we want to keep the model code parallelism-free, we don't want to manually turn AuxOutput.lse into a DTensor.
I think an alternative is to handle it in FlexAttentionWrapper; if that is a better way, I will create another PR to fix it.
tianyu-l
left a comment
LGTM, some minor final comments
    - Up to `window_size - 1` previous tokens

    Args:
        window_size: The maximum number of tokens to attend to (including current token).
            Must be >= 1. A window_size of 1 means attend only to self.
    self.use_grouped_mm = use_grouped_mm
    self.swiglu_limit = swiglu_limit

    self.mlp1_weight = nn.Parameter(torch.empty((num_experts, hidden_dim * 2, dim)))
Please add a comment to indicate which dim is the input dim and which is the output dim.
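A hypothetical sketch of the dim-labeling comment being requested. The sizes are illustrative; only the parameter shape follows the diff above:

```python
import torch
import torch.nn as nn

# Illustrative sizes only.
num_experts, dim, hidden_dim = 2, 4, 8

mlp1_weight = nn.Parameter(
    # shape: (num_experts, out_dim = hidden_dim * 2, in_dim = dim);
    # transposed at use time, matching the main moe.py convention.
    torch.empty((num_experts, hidden_dim * 2, dim))
)
```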
Keep developing on top of #1559. Thanks @KhoomeiK for the initial contribution!
Initialized from the same seed checkpoint, with seed=0 and deterministic=True.
GPT-oss

Run 1: dp_shard = 2
Run 2: dp_shard = 2, TP degree = 2 (NGPU=4)

Run 3: dp_shard = 2, TP degree = 2, EP degree = 2 (NGPU=4)

Run 4: dp_shard = 2, TP degree = 2, EP degree = 2, ETP degree = 2 (NGPU=4)

Run 5: dp_shard = 2, EP degree = 2 (NGPU=2)
