[llama4] enable expert parallel on the same device mesh as tp (tp2ep) #1269
hann-wang wants to merge 8 commits into pytorch:main from
Conversation
Thank you for the PR! I'll take a look.
tianyu-l left a comment
Thank you very much for the PR! I think the idea sounds very interesting.
I have some high-level questions:
- Compared with PR #731, which you refer to, this tp2ep implementation is all-to-all based rather than all-gather / reduce-scatter based (see the schematic sketch after this list). Do you have any sense of which is more efficient, assuming both are correct?
- Personally I think the implementation is a bit too intrusive into the model code, whereas torchtitan tries to avoid exactly that (https://github.com/pytorch/torchtitan/blob/main/README.md?plain=1#L38). Do you think there is a chance you could make it cleaner?
- Do you have any testing showing that your implementation is correct, e.g. loss curves compared with single-device training?
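For concreteness, here is a schematic of the all-gather / reduce-scatter pattern the first question refers to, based on my reading of the #731-style approach; the function names, the use of functional collectives, and the zero-fill routing detail are illustrative assumptions rather than the actual #731 code:

```python
import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_gather_tensor, reduce_scatter_tensor

def moe_allgather_reducescatter(x_shard: torch.Tensor, local_experts, ep_group: dist.ProcessGroup):
    """x_shard: (local_seq_len, dim), sharded along seqlen across the EP mesh."""
    # Replicate all tokens on every rank.
    x_full = all_gather_tensor(x_shard, gather_dim=0, group=ep_group)
    # Each rank runs only the experts it owns; tokens routed to experts on
    # other ranks contribute zeros here.
    partial_out = local_experts(x_full)
    # Sum the partial expert outputs and re-shard along seqlen.
    return reduce_scatter_tensor(partial_out, "sum", scatter_dim=0, group=ep_group)
```

By contrast, the all-to-all approach in this PR exchanges only the tokens each expert actually needs (sketched in the PR description below).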
| "moe": | ||
| PrepareModuleInputOutput( | ||
| input_layouts=(Shard(1), ), | ||
| desired_input_layouts=(Shard(1), ), |
If I understand correctly, the input to the router is sharded. That might break the semantics / correctness of the load-balancing algorithm, given that the update to self.tokens_per_expert is local to each EP rank.
https://github.com/pytorch/torchtitan/pull/1269/files#diff-87cc24d85c768f0b3d1f5c54cca39dc9de52ee20e8f601814c3200722901aee5R293
Thank you for pointing out this issue. We need an all_reduce across all EP groups.
Fixed in b87aa1e
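For reference, a minimal sketch of that kind of fix; the function boundary, the `ep_group` handle, and where the call is placed are assumptions for illustration, not necessarily what b87aa1e does:

```python
import torch
import torch.distributed as dist

def sync_tokens_per_expert(tokens_per_expert: torch.Tensor, ep_group: dist.ProcessGroup) -> None:
    # Sum per-expert token counts over every rank in the EP group so the
    # load-balancing statistics see the full (global) batch rather than
    # only the seqlen shard local to this rank.
    dist.all_reduce(tokens_per_expert, op=dist.ReduceOp.SUM, group=ep_group)
```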

This PR is built on top of the concept introduced in #731.
In this implementation, the input to the MoE module is sharded along the seqlen dimension rather than being replicated. After gathering tokens from different EP ranks using `all_to_all_single_autograd`, the output tokens remain sharded along the seqlen dimension. To activate this feature, set `enable_tp2ep = true` in the configuration file.

cc @tianyu-l
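To make the data flow above concrete, here is a minimal sketch of the dispatch-and-combine round trip; the helper name, the split-size bookkeeping, and the import path for `all_to_all_single_autograd` are assumptions for illustration, and the PR's actual routing code is more involved:

```python
import torch
import torch.distributed as dist
# In recent PyTorch this lives in the functional collectives module;
# the PR may import it from elsewhere.
from torch.distributed._functional_collectives import all_to_all_single_autograd

def moe_dispatch_combine(
    x_shard: torch.Tensor,      # (num_local_tokens, dim): tokens from this rank's seqlen shard
    input_splits: list[int],    # how many tokens this rank sends to each EP rank
    output_splits: list[int],   # how many tokens this rank receives from each EP rank
    local_experts,              # callable running the experts owned by this rank
    ep_group: dist.ProcessGroup,
) -> torch.Tensor:
    # Dispatch: send each token to the EP rank that owns its selected expert.
    recv = all_to_all_single_autograd(x_shard, output_splits, input_splits, ep_group)
    # Run only the local experts on the received tokens.
    out = local_experts(recv)
    # Combine: reverse the exchange, so the result is again sharded along seqlen.
    return all_to_all_single_autograd(out, input_splits, output_splits, ep_group)
```

Since both all-to-alls are autograd-aware, gradients flow back along the same routes in the backward pass.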