Conversation
Here's the config file I was using. I used the same config file for both the logs in the description:

```toml
[job]
dump_folder = "./outputs"
description = "Gpt-oss debug training"
print_config = false

[profiling]
[metrics]
[model]
[optimizer]
[lr_scheduler]
[training]
[training.dataloader]
[parallelism]
[checkpoint]
[activation_checkpoint]
[compile]
[validation]
```
|
|
Thx @chelsea0x3b I think the problem is we don't know what is "correct" for [bf16 or fp32] reduce. In DeepEP it seems to be fp32 reduce. In Megatron the probs are actually cast back to bf16 right after the sigmoid (so the gather can also be at least 2x faster). In downstream inference libs, e.g. SGLang, both bf16 and fp32 exist (at some point there is something about fp8 training/inference consistency, sglang for inference and Megatron for training). So it's hard to tell what is correct, especially when scaling up. |
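For concreteness, here's a minimal sketch of the two conventions being contrasted (tensor names and shapes are made up for illustration, not torchtitan's or Megatron's actual code): keeping the routing probs in fp32 vs. Megatron-style casting back to bf16 right after the sigmoid.

```python
import torch

tokens, hidden, experts = 8, 16, 4
x = torch.randn(tokens, hidden, dtype=torch.bfloat16)
gate_weight = torch.randn(experts, hidden, dtype=torch.bfloat16)

# "fp32 reduce" convention (DeepEP-style): upcast, keep probs in fp32.
logits_fp32 = x.float() @ gate_weight.float().t()
probs_fp32 = torch.sigmoid(logits_fp32)

# Megatron-style: sigmoid still computed in fp32, but probs cast right
# back to bf16, so the downstream gather moves half the bytes.
probs_bf16 = probs_fp32.to(torch.bfloat16)

print(probs_fp32.dtype, probs_bf16.dtype)  # torch.float32 torch.bfloat16
```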
Yeah, agreed that's the real issue.
I found that Megatron does have a recommended
IMO making this configurable is reasonable. |
|
IIUC in Megatron |
|
oh i see your point, yes, you are right. if we force logits to be fp32 then the rest should also stay in fp32. |
|
Just caught up on this conversation - so should I add some configuration alongside this? |
|
I suggest making it configurable and keeping the fp32 path as the default for BC |
|
actually i did another check on megatron. if we set it up that way it should be equivalent to what we do. TBH i don't see any reason not to do this (by default for the bf16 path). the rest of the code then seems to make sense to keep in fp32? |
|
@rakkit understanding check: you're saying we should by default have something like

```diff
 class TokenChoiceTopKRouter(nn.Module):
     def forward(
         self, x: torch.Tensor, expert_bias: torch.Tensor | None = None
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        scores = self.gate(x)
+        scores = torch.mm(x, self.gate.weight.t(), out_dtype=torch.float32)
```

? Since we immediately cast |
|
yes @garrett361. we need a "fix" for torchtitan first, then optionally offer a bf16 and/or fully fp32 path. |
this is about the backwards slowness here? Whatever solution is fine to me. Either fixing bmm or moving to elementwise-prod-then-sum. Seems like the former should be strictly faster, though, being a single op. Which is why I introduced the bmm call. |
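The two candidate combine paths mentioned here can be sketched like this (shapes are illustrative, not torchtitan's actual router code); both compute the same weighted sum of top-k expert outputs.

```python
import torch

tokens, topk, hidden = 8, 2, 16
expert_outs = torch.randn(tokens, topk, hidden)
scores = torch.rand(tokens, topk)

# Single-op path: one batched matmul, [1, topk] @ [topk, hidden] per token.
combined_bmm = torch.bmm(scores.unsqueeze(1), expert_outs).squeeze(1)

# Elementwise-prod-then-sum path: broadcast multiply, reduce over topk.
combined_sum = (expert_outs * scores.unsqueeze(-1)).sum(dim=1)

print(torch.allclose(combined_bmm, combined_sum, atol=1e-5))  # True
```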
|
@garrett361 yes, BMM is faster (even with those weird SM80 kernels). IDK if it's a bug we need to fix or if it's expected to be like that. |
Oh ok, I misunderstood the other thread; thought bmm was slower, somehow. Which was confusing me because I thought I'd tested it 😅 So is this an accurate summary?
|
@garrett361 yes, and it more or less aligns with Megatron. (it still differs from deepseek's HF inference code; we can comment in the code in case someone wants something like an "ultimate fp32" version) |
|
Another piece of data: openai's official gpt oss implementation doesn't use f32 at all: https://github.com/openai/gpt-oss/blob/main/gpt_oss/torch/model.py#L316 |
|
And llama4 moe router casts right back to input dtype: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama4.py#L152 |
|
Yeah I saw that gpt-oss code as well @chelsea0x3b. Hadn't looked at the HF llama4 code. The meta llama4 code doesn't do any pre-sigmoid upcast at all. Not much consensus. |
|
IMO since there isn't really consensus I'd lean towards casting back to bfloat16 immediately (as in the PR) because it reduces memory by a decent amount (5% is like 9GB on b200, which is substantial for users not on large GPUs). |
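For a rough sense of where savings like that come from, a back-of-the-envelope sketch (all numbers below are illustrative, not the PR's actual config): any [tokens, num_experts]-sized activation kept around for backward costs twice as much in fp32 as in bf16, once per MoE layer.

```python
# Hypothetical sizes, just to show the arithmetic; not the real run.
tokens = 8 * 8192      # batch * seq_len
num_experts = 128
layers = 36

bytes_fp32 = tokens * num_experts * 4 * layers  # 4 bytes/elem in fp32
bytes_bf16 = tokens * num_experts * 2 * layers  # 2 bytes/elem in bf16
saved_gib = (bytes_fp32 - bytes_bf16) / 2**30
print(f"{saved_gib:.4f} GiB saved")  # 0.5625 GiB for these toy numbers
```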
|
@tianyu-l do you have a strong opinion here? |
|
@garrett361 #2448 would something like this work? the test no longer seems to have the same issue (failing on inductor) |
|
@acisseJZhong i had this version ages ago and compile worked fine |
|
lol @rakkit what prevents you from landing that change? shall we revive or maybe @chelsea0x3b could just fix this and land! |
|
@acisseJZhong some time ago i had that, and from my ablation (7b-a1b MoE, 64 experts) i didn't see a significant diff in performance. in mid-2025 we needed to train something on 40GB A100s, so i decided to keep bf16 (we trained another 7b x 1T tokens fully bf16 and it still kind of works). but TBH i never thought about TP and mixed precision stuff at the beginning |
|
sorry, so we want to go with the autocast? just getting a lot of mixed messages and i don't know who to listen to lol. about to go out of town for a week, would love to get this PR merged |
|
@chelsea0x3b can you help test numerics with the autocast approach (#2448) in your PR? sorry for all the back and forth 🤣 hope we can land soon. |
|
@acisseJZhong numerics for autocast look good with DP/EP/ETP, PR should be good to go |
torchtitan/models/common/moe/moe.py
i think you can remove the cast to float32? as scores is already fp32. Maybe add a comment saying scores is by default fp32 already. Thanks!
|
@chelsea0x3b pls remove the cast to fp32 for scores, and then feel free to land it! Would appreciate it if you pasted the testing result (numerics don't change) in the PR description! |
|
I'm away from my laptop so I don't have access to the logs anymore, but I updated the description and removed the redundant `.to` call |
|
In the PR summary, why does AC=none result in more memory in bf16 than in fp32? Could you include the parallelism config you are using? If using EP, we should turn on load-balanced routing to factor out the imbalance https://github.com/pytorch/torchtitan/blob/main/torchtitan/config/configs.py#L387
I don't think the numerics should be the same even if we fix the random seed and turn on deterministic mode, because an fp32 gate matmul should give different results than a bf16 gate matmul. |
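A quick illustration of that point (toy shapes, not the PR's model): the two gate matmuls generally do not agree bit-for-bit, because the bf16 path rounds its output to bf16 before any comparison.

```python
import torch

torch.manual_seed(0)
x = torch.randn(64, 128, dtype=torch.bfloat16)
w = torch.randn(8, 128, dtype=torch.bfloat16)

scores_bf16 = (x @ w.t()).float()        # matmul emitted in bf16, then upcast
scores_fp32 = x.float() @ w.float().t()  # upcast first, matmul in fp32

max_diff = (scores_bf16 - scores_fp32).abs().max()
print(max_diff.item() > 0)  # True: the results differ
```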
|
@tianyu-l because it kept hitting a ton of cuda memory reallocations. check out the full log i posted for that case: the memory usage varies a lot between steps and the number of reallocations increased each step. i was just taking the mem usage from a single step, so it was hard to pick which step for that one because there wasn't a good "average"; the other cases were all much more consistent. and yes, the fp32 run produced different numbers, especially across the different cases, but the losses all followed the same pattern and were very close (within about 0.1 of each other or closer) |
|
@chelsea0x3b Regarding cuda memory reallocation, I guess you could use a smaller model, e.g. even the debugmodel, to showcase this. Please fix the lint error. |
Original discussion #2225.
Per the comments, this PR now changes the gate computation to happen in fp32.
Run on 8xb200.
output from runs
Full AC with float32 gate (this PR)
Full AC with bfloat16 gate (main branch)
No AC with bfloat16 gate (main branch)
No AC with float32 gate
Hope this helps!
The numerics don't change with this.