Remove device to host synchronizations from repeat_interleave and tail_slack #2440
rthekini-aws wants to merge 1 commit into pytorch:main
Conversation
Some runs with GPT-OSS 20B and fixed random seed:
Thanks. Could you show numerics? Should we expect numerics to not change if you fix the seed and enable determinism?
acisseJZhong left a comment
Overall LGTM! Thanks for showing the before and after improvement.
My earlier results are with the random seed fixed, but @tianyu-l, @wwwjn, any recommendations on how to unblock here without root-causing the pre-existing accuracy issue? I will start by trying a couple of other configs (e.g., debug_model) to see if I can show matching loss curves there instead.
@rthekini-aws ah you are right. In my previous experiment I also observed NaN when deterministic mode is on. I think it's worth triaging what caused the NaN, because otherwise any change would make people nervous. Could you try replacing the flex attn module with sdpa using causal attention? Mathematically it's not correct anymore, but if the NaN is caused by flex attn, then you can test whether your change to MoE preserves numerics under deterministic mode.
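A minimal sketch of the suggested swap (the function name and tensor layout here are assumptions, not code from this PR):

```python
import torch
import torch.nn.functional as F

def sdpa_causal(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v: [batch, heads, seq_len, head_dim]
    # Temporary stand-in for the flex attention call while triaging NaNs:
    # mathematically different from the model's real attention mask, but
    # enough to check whether the MoE change preserves numerics when
    # deterministic mode is on.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```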
Force-pushed from 6789396 to 8531311
I root caused the NaN to uninitialized router bias (which is filled with NaNs in deterministic mode). Here's the PR for that issue: #2450. After fixing the initialization, I'm seeing matching results with and without this change. With fix:
Maybe we should also change the gpt-oss bias init to 0? @rthekini-aws
@tianyu-l Are you referring to the fix at #2450? Or are you referring to the fact that q/k/v/o biases are initialized with trunc_normal rather than zeros?
@rthekini-aws I mean gpt-oss attention linear bias and gpt-oss moe bias (https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/gpt_oss/moe.py#L269) |


The current approach for the bias in the expert gate/up/down projections uses `repeat_interleave`, which produces a dynamic shape, and then uses `tail_slack` to pad to a static shape. This incurs multiple device-to-host synchronizations: `repeat_interleave` without the `output_size` parameter, and the `.item()` call from `int(offsets[-1])`. Specifically, the `repeat_interleave` and `tail_slack` output allocation sizes both depend on the data in `num_tokens_per_expert`, even though their concatenated output has a statically known shape.

We can solve this problem by reordering operations slightly. In particular, we can pad first and then run `repeat_interleave` with the `output_size` parameter to directly produce a tensor of static shape with the correct amount of padding, without relying on multiple device-to-host syncs. This should be mathematically equivalent.
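As a rough illustration of the pad-then-`repeat_interleave` reordering (all names and shapes below are assumptions for the sketch, not the PR's actual code):

```python
import torch

def expert_ids_static(num_tokens_per_expert: torch.Tensor, max_len: int) -> torch.Tensor:
    # Pad first: route the slack slots to an extra "padding" bucket so the
    # counts sum to the static max_len without ever reading them on the host.
    num_experts = num_tokens_per_expert.numel()
    slack = (max_len - num_tokens_per_expert.sum()).reshape(1)
    counts = torch.cat([num_tokens_per_expert, slack])
    ids = torch.arange(num_experts + 1, device=counts.device)
    # Passing output_size gives repeat_interleave a statically known output
    # shape, so no device-to-host sync is needed to size the allocation.
    return torch.repeat_interleave(ids, counts, output_size=max_len)
```

For example, counts `[2, 1]` with `max_len=5` yield `[0, 0, 1, 2, 2]`, where index 2 marks the padding bucket.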