Remove device to host synchronizations from repeat_interleave and tail_slack#2440

Open

rthekini-aws wants to merge 1 commit into pytorch:main from rthekini-aws:remove-repeat-interleave-d2h-sync

Conversation

@rthekini-aws (Contributor):

The current approach for the bias in the expert gate/up/down projections uses repeat_interleave, which produces a dynamic shape, and then uses tail_slack to pad to a static shape. This incurs multiple device-to-host synchronizations: repeat_interleave without the output_size parameter, and the .item() call from int(offsets[-1]). Specifically, the repeat_interleave and tail_slack output allocation sizes both depend on the data in num_tokens_per_expert, but once they are concatenated the output has a statically known shape.

We can solve this by reordering operations slightly: pad first, then run repeat_interleave with the output_size parameter to directly produce a tensor of static shape with the correct amount of padding, without any device-to-host syncs. This should be mathematically equivalent.
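The reordering can be illustrated with a minimal sketch. All names and shapes below are hypothetical, chosen for illustration; this is not the actual torchtitan code.

```python
import torch

# Hypothetical setup: per-expert bias rows and token counts.
num_experts, dim = 4, 2
bias = torch.arange(num_experts, dtype=torch.float32).unsqueeze(1).expand(num_experts, dim)
num_tokens_per_expert = torch.tensor([3, 1, 0, 2])
max_len = 8  # statically known padded length

# Before: repeat_interleave without output_size must read the repeats back
# to the host to size its output, and computing the tail slack needs another
# sync (the int(offsets[-1]) / .item() call). The shape is dynamic until then.
expanded = torch.repeat_interleave(bias, num_tokens_per_expert, dim=0)
tail_slack = max_len - expanded.shape[0]
padded_old = torch.cat([expanded, expanded.new_zeros(tail_slack, dim)])

# After: fold the padding into the repeats first, then pass output_size so the
# output shape is known statically and no host round-trip is required.
repeats = num_tokens_per_expert.clone()
repeats[-1] += max_len - num_tokens_per_expert.sum()
padded_new = torch.repeat_interleave(bias, repeats, dim=0, output_size=max_len)

# The first sum(num_tokens_per_expert) rows agree exactly; the slack rows hold
# the last expert's bias instead of zeros, which is harmless when those rows
# correspond to padding tokens whose outputs are discarded.
n_valid = int(num_tokens_per_expert.sum())
assert padded_new.shape == (max_len, dim)
assert torch.equal(padded_new[:n_valid], padded_old[:n_valid])
```

Note that the two versions differ only in the contents of the slack rows, which is why the change is mathematically equivalent for the valid tokens.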

meta-cla bot added the CLA Signed label (this label is managed by the Meta Open Source bot) on Feb 25, 2026.
@rthekini-aws (Contributor, Author):

Some runs with GPT-OSS 20B and fixed random seed:

Without fix, run 1
[rank0]:[titan] 2026-02-25 22:35:58,660 - root - INFO - step:  1  loss: 12.78767  grad_norm:  9.8242  memory: 43.51GiB(31.12%)  tps: 400  tflops: 9.63  mfu: 0.97%
[rank0]:[titan] 2026-02-25 22:35:58,660 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2026-02-25 22:36:00,112 - root - INFO - step:  2  loss: 10.84669  grad_norm: 31.3818  memory: 57.65GiB(41.24%)  tps: 2,822  tflops: 67.91  mfu: 6.87%
[rank0]:[titan] 2026-02-25 22:36:01,058 - root - INFO - step:  3  loss:  7.94056  grad_norm: 39.9781  memory: 57.65GiB(41.24%)  tps: 4,334  tflops: 104.29  mfu: 10.55%
[rank0]:[titan] 2026-02-25 22:36:02,003 - root - INFO - step:  4  loss:  6.71291  grad_norm: 20.8837  memory: 57.65GiB(41.24%)  tps: 4,332  tflops: 104.25  mfu: 10.54%
[rank0]:[titan] 2026-02-25 22:36:02,950 - root - INFO - step:  5  loss:  6.64668  grad_norm: 40.3171  memory: 57.65GiB(41.24%)  tps: 4,327  tflops: 104.13  mfu: 10.53%
[rank0]:/shared/rthekini/torchtitan/.venv/lib/python3.10/site-packages/torch/profiler/profiler.py:217: UserWarning: Warning: Profiler clears events at the end of each cycle.Only events from the current cycle will be reported.To keep events across cycles, set acc_events=True.
[rank0]:  _warn_once(
[rank0]:[titan] 2026-02-25 22:36:03,899 - root - INFO - step:  6  loss:  6.52776  grad_norm: 25.5808  memory: 57.65GiB(41.24%)  tps: 4,320  tflops: 103.97  mfu: 10.51%
[rank0]:[titan] 2026-02-25 22:36:05,078 - root - INFO - step:  7  loss:  5.95308  grad_norm: 20.0180  memory: 57.65GiB(41.24%)  tps: 3,474  tflops: 83.61  mfu: 8.45%
[rank0]:[titan] 2026-02-25 22:36:06,030 - root - INFO - step:  8  loss:  5.30508  grad_norm: 15.6962  memory: 57.65GiB(41.24%)  tps: 4,307  tflops: 103.64  mfu: 10.48%
[rank0]:[titan] 2026-02-25 22:36:06,990 - root - INFO - step:  9  loss:  4.49603  grad_norm: 12.7977  memory: 57.65GiB(41.24%)  tps: 4,265  tflops: 102.64  mfu: 10.38%
[rank0]:[titan] 2026-02-25 22:36:07,981 - root - INFO - step: 10  loss:  4.22573  grad_norm: 22.8335  memory: 57.65GiB(41.24%)  tps: 4,136  tflops: 99.54  mfu: 10.06%


Without fix, run 2
[rank0]:[titan] 2026-02-25 22:38:00,727 - root - INFO - step:  1  loss: 12.78767  grad_norm:  9.8242  memory: 43.51GiB(31.12%)  tps: 402  tflops: 9.67  mfu: 0.98%
[rank0]:[titan] 2026-02-25 22:38:00,728 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2026-02-25 22:38:02,181 - root - INFO - step:  2  loss: 10.84683  grad_norm: 31.3823  memory: 57.65GiB(41.24%)  tps: 2,819  tflops: 67.84  mfu: 6.86%
[rank0]:[titan] 2026-02-25 22:38:03,125 - root - INFO - step:  3  loss:  7.94018  grad_norm: 39.9738  memory: 57.65GiB(41.24%)  tps: 4,341  tflops: 104.46  mfu: 10.56%
[rank0]:[titan] 2026-02-25 22:38:04,071 - root - INFO - step:  4  loss:  6.71299  grad_norm: 20.8834  memory: 57.65GiB(41.24%)  tps: 4,331  tflops: 104.23  mfu: 10.54%
[rank0]:[titan] 2026-02-25 22:38:05,020 - root - INFO - step:  5  loss:  6.64653  grad_norm: 40.3180  memory: 57.65GiB(41.24%)  tps: 4,316  tflops: 103.86  mfu: 10.50%
[rank0]:/shared/rthekini/torchtitan/.venv/lib/python3.10/site-packages/torch/profiler/profiler.py:217: UserWarning: Warning: Profiler clears events at the end of each cycle.Only events from the current cycle will be reported.To keep events across cycles, set acc_events=True.
[rank0]:  _warn_once(
[rank0]:[titan] 2026-02-25 22:38:05,970 - root - INFO - step:  6  loss:  6.52058  grad_norm: 25.5831  memory: 57.65GiB(41.24%)  tps: 4,316  tflops: 103.87  mfu: 10.50%
[rank0]:[titan] 2026-02-25 22:38:07,149 - root - INFO - step:  7  loss:  5.94565  grad_norm: 20.0171  memory: 57.65GiB(41.24%)  tps: 3,474  tflops: 83.61  mfu: 8.45%
[rank0]:[titan] 2026-02-25 22:38:08,105 - root - INFO - step:  8  loss:  5.29762  grad_norm: 15.6942  memory: 57.65GiB(41.24%)  tps: 4,289  tflops: 103.23  mfu: 10.44%
[rank0]:[titan] 2026-02-25 22:38:09,062 - root - INFO - step:  9  loss:  4.49220  grad_norm: 12.3711  memory: 57.65GiB(41.24%)  tps: 4,280  tflops: 103.01  mfu: 10.42%
[rank0]:[titan] 2026-02-25 22:38:10,058 - root - INFO - step: 10  loss:  4.27104  grad_norm: 24.2117  memory: 57.65GiB(41.24%)  tps: 4,116  tflops: 99.05  mfu: 10.02%


With fix, run 1
[rank0]:[titan] 2026-02-25 22:30:06,805 - root - INFO - step:  1  loss: 12.78767  grad_norm:  9.8242  memory: 43.51GiB(31.12%)  tps: 408  tflops: 9.82  mfu: 0.99%
[rank0]:[titan] 2026-02-25 22:30:06,805 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2026-02-25 22:30:08,081 - root - INFO - step:  2  loss: 10.84649  grad_norm: 31.3810  memory: 56.97GiB(40.75%)  tps: 3,212  tflops: 77.29  mfu: 7.81%
[rank0]:[titan] 2026-02-25 22:30:08,956 - root - INFO - step:  3  loss:  7.94050  grad_norm: 39.9799  memory: 56.97GiB(40.75%)  tps: 4,683  tflops: 112.71  mfu: 11.40%
[rank0]:[titan] 2026-02-25 22:30:09,821 - root - INFO - step:  4  loss:  6.70554  grad_norm: 20.9773  memory: 56.97GiB(40.75%)  tps: 4,736  tflops: 113.97  mfu: 11.52%
[rank0]:[titan] 2026-02-25 22:30:10,696 - root - INFO - step:  5  loss:  6.65492  grad_norm: 40.4038  memory: 56.97GiB(40.75%)  tps: 4,687  tflops: 112.79  mfu: 11.40%
[rank0]:/shared/rthekini/torchtitan/.venv/lib/python3.10/site-packages/torch/profiler/profiler.py:217: UserWarning: Warning: Profiler clears events at the end of each cycle.Only events from the current cycle will be reported.To keep events across cycles, set acc_events=True.
[rank0]:  _warn_once(
[rank0]:[titan] 2026-02-25 22:30:11,557 - root - INFO - step:  6  loss:  6.54898  grad_norm: 25.7581  memory: 56.97GiB(40.75%)  tps: 4,757  tflops: 114.48  mfu: 11.58%
[rank0]:[titan] 2026-02-25 22:30:12,655 - root - INFO - step:  7  loss:  5.91954  grad_norm: 19.1997  memory: 56.97GiB(40.75%)  tps: 3,733  tflops: 89.84  mfu: 9.08%
[rank0]:[titan] 2026-02-25 22:30:13,523 - root - INFO - step:  8  loss:  5.26176  grad_norm: 15.5981  memory: 56.97GiB(40.75%)  tps: 4,720  tflops: 113.60  mfu: 11.49%
[rank0]:[titan] 2026-02-25 22:30:14,397 - root - INFO - step:  9  loss:  4.48445  grad_norm: 12.4167  memory: 56.97GiB(40.75%)  tps: 4,691  tflops: 112.89  mfu: 11.41%
[rank0]:[titan] 2026-02-25 22:30:15,270 - root - INFO - step: 10  loss:  4.28407  grad_norm: 24.9686  memory: 56.97GiB(40.75%)  tps: 4,693  tflops: 112.94  mfu: 11.42%


With fix, run 2
[rank0]:[titan] 2026-02-25 22:29:20,952 - root - INFO - step:  1  loss: 12.78767  grad_norm:  9.8242  memory: 43.51GiB(31.12%)  tps: 345  tflops: 8.31  mfu: 0.84%
[rank0]:[titan] 2026-02-25 22:29:20,953 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:/shared/rthekini/torchtitan/torchtitan/distributed/utils.py:395: UserWarning: Set timeout is now only supported for either nccl or gloo.
[rank0]:  torch.distributed.distributed_c10d._set_pg_timeout(timeout, group)
[rank0]:[titan] 2026-02-25 22:29:22,237 - root - INFO - step:  2  loss: 10.84647  grad_norm: 31.3820  memory: 56.97GiB(40.75%)  tps: 3,190  tflops: 76.76  mfu: 7.76%
[rank0]:[titan] 2026-02-25 22:29:23,106 - root - INFO - step:  3  loss:  7.94068  grad_norm: 39.9667  memory: 56.97GiB(40.75%)  tps: 4,720  tflops: 113.60  mfu: 11.49%
[rank0]:[titan] 2026-02-25 22:29:23,963 - root - INFO - step:  4  loss:  6.70955  grad_norm: 20.8877  memory: 56.97GiB(40.75%)  tps: 4,782  tflops: 115.07  mfu: 11.64%
[rank0]:[titan] 2026-02-25 22:29:24,825 - root - INFO - step:  5  loss:  6.64535  grad_norm: 40.3155  memory: 56.97GiB(40.75%)  tps: 4,754  tflops: 114.40  mfu: 11.57%
[rank0]:/shared/rthekini/torchtitan/.venv/lib/python3.10/site-packages/torch/profiler/profiler.py:217: UserWarning: Warning: Profiler clears events at the end of each cycle.Only events from the current cycle will be reported.To keep events across cycles, set acc_events=True.
[rank0]:  _warn_once(
[rank0]:[titan] 2026-02-25 22:29:25,685 - root - INFO - step:  6  loss:  6.52209  grad_norm: 25.6091  memory: 56.97GiB(40.75%)  tps: 4,764  tflops: 114.64  mfu: 11.59%
[rank0]:[titan] 2026-02-25 22:29:26,801 - root - INFO - step:  7  loss:  5.92642  grad_norm: 19.5263  memory: 56.97GiB(40.75%)  tps: 3,672  tflops: 88.37  mfu: 8.94%
[rank0]:[titan] 2026-02-25 22:29:27,671 - root - INFO - step:  8  loss:  5.27531  grad_norm: 15.7020  memory: 56.97GiB(40.75%)  tps: 4,710  tflops: 113.35  mfu: 11.46%
[rank0]:[titan] 2026-02-25 22:29:28,541 - root - INFO - step:  9  loss:  4.47685  grad_norm: 12.2892  memory: 56.97GiB(40.75%)  tps: 4,714  tflops: 113.45  mfu: 11.47%
[rank0]:[titan] 2026-02-25 22:29:29,416 - root - INFO - step: 10  loss:  4.27445  grad_norm: 24.9381  memory: 56.97GiB(40.75%)  tps: 4,680  tflops: 112.62  mfu: 11.39%

@rthekini-aws (Contributor, Author):

Without fix: [screenshot, 2026-02-25 at 5:34:29 PM]
With fix: [screenshot, 2026-02-25 at 5:34:17 PM]

@tianyu-l (Contributor):

Thanks. Could you show numerics? Should we expect numerics not to change if you fix the seed and enable determinism?

@wwwjn (Contributor) left a comment:

Thanks, this fix looks mathematically correct to me. Can you provide the loss with deterministic mode + a seed checkpoint (instructions here)?

@acisseJZhong (Contributor) left a comment:

Overall LGTM! Thanks for showing the before-and-after improvement.

@rthekini-aws (Contributor, Author):

My earlier results are with the random seed fixed but deterministic=False, which is why I included 2 runs of each. For the 20B config I'm running, when I tried deterministic=True I saw NaNs both without and with my fix. I haven't investigated that particular correctness issue, but I did notice issues with trunc_normal_ earlier.

@tianyu-l, @wwwjn Any recommendations on how to unblock here without root causing the pre-existing accuracy issue? I will start by trying a couple of other configs (e.g., debug_model) to see if I can show matching loss curves there instead.

@tianyu-l (Contributor):

@rthekini-aws ah, you are right. In my previous experiment I also observed NaN when deterministic mode is on.

I think it's worth triaging what caused the NaN, because otherwise any change would make people nervous.

Could you try replacing the flex attention module with SDPA using causal attention? Mathematically it's no longer the same model, but if the NaN is caused by flex attention, then you can test whether your change to the MoE preserves numerics under deterministic mode.

rthekini-aws force-pushed the remove-repeat-interleave-d2h-sync branch from 6789396 to 8531311 (February 27, 2026, 01:10).
@rthekini-aws (Contributor, Author):

I root-caused the NaN to an uninitialized router bias (which is filled with NaNs in deterministic mode). Here's the PR for that issue: #2450

After fixing the initialization, I'm seeing matching results with and without the repeat_interleave fix.
Without fix:

=== seed=0 ===
[rank0]:[titan] 2026-02-27 02:09:52,966 - root - INFO - step:  1  loss:  8.13326  grad_norm:  2.0193  memory: 15.15GiB(10.83%)  tps: 4,145  tflops: 2.80  mfu: 0.28%
[rank0]:[titan] 2026-02-27 02:09:53,839 - root - INFO - step:  2  loss:  6.62090  grad_norm:  2.6836  memory: 15.59GiB(11.15%)  tps: 18,767  tflops: 12.68  mfu: 1.28%
[rank0]:[titan] 2026-02-27 02:09:54,039 - root - INFO - step:  3  loss:  4.99889  grad_norm:  3.1322  memory: 15.59GiB(11.15%)  tps: 82,248  tflops: 55.56  mfu: 5.62%
[rank0]:[titan] 2026-02-27 02:09:54,242 - root - INFO - step:  4  loss:  4.77399  grad_norm:  2.5997  memory: 15.59GiB(11.15%)  tps: 80,724  tflops: 54.53  mfu: 5.51%
[rank0]:[titan] 2026-02-27 02:09:54,448 - root - INFO - step:  5  loss:  4.45365  grad_norm:  2.6487  memory: 15.59GiB(11.15%)  tps: 79,644  tflops: 53.80  mfu: 5.44%
=== seed=42 ===
[rank0]:[titan] 2026-02-27 02:10:21,030 - root - INFO - step:  1  loss:  8.08693  grad_norm:  2.3529  memory: 15.15GiB(10.83%)  tps: 4,200  tflops: 2.84  mfu: 0.29%
[rank0]:[titan] 2026-02-27 02:10:21,911 - root - INFO - step:  2  loss:  6.39597  grad_norm:  2.9923  memory: 15.94GiB(11.40%)  tps: 18,588  tflops: 12.56  mfu: 1.27%
[rank0]:[titan] 2026-02-27 02:10:22,119 - root - INFO - step:  3  loss:  4.89488  grad_norm:  2.9787  memory: 15.94GiB(11.40%)  tps: 79,097  tflops: 53.43  mfu: 5.40%
[rank0]:[titan] 2026-02-27 02:10:22,329 - root - INFO - step:  4  loss:  4.59138  grad_norm:  2.4255  memory: 15.94GiB(11.40%)  tps: 78,106  tflops: 52.76  mfu: 5.33%
[rank0]:[titan] 2026-02-27 02:10:22,541 - root - INFO - step:  5  loss:  4.28487  grad_norm:  2.4017  memory: 15.94GiB(11.40%)  tps: 77,706  tflops: 52.49  mfu: 5.31%
=== seed=123 ===
[rank0]:[titan] 2026-02-27 02:10:50,047 - root - INFO - step:  1  loss:  7.98198  grad_norm:  2.1047  memory: 15.15GiB(10.83%)  tps: 4,126  tflops: 2.79  mfu: 0.28%
[rank0]:[titan] 2026-02-27 02:10:50,919 - root - INFO - step:  2  loss:  6.41391  grad_norm:  2.8769  memory: 15.45GiB(11.05%)  tps: 18,792  tflops: 12.69  mfu: 1.28%
[rank0]:[titan] 2026-02-27 02:10:51,122 - root - INFO - step:  3  loss:  4.95427  grad_norm:  3.0792  memory: 15.45GiB(11.05%)  tps: 80,819  tflops: 54.60  mfu: 5.52%
[rank0]:[titan] 2026-02-27 02:10:51,326 - root - INFO - step:  4  loss:  4.74782  grad_norm:  2.5689  memory: 15.45GiB(11.05%)  tps: 80,524  tflops: 54.40  mfu: 5.50%
[rank0]:[titan] 2026-02-27 02:10:51,533 - root - INFO - step:  5  loss:  4.40844  grad_norm:  2.5907  memory: 15.45GiB(11.05%)  tps: 79,162  tflops: 53.48  mfu: 5.41%
=== seed=999 ===
[rank0]:[titan] 2026-02-27 02:11:18,501 - root - INFO - step:  1  loss:  8.13762  grad_norm:  2.0673  memory: 15.15GiB(10.83%)  tps: 4,184  tflops: 2.83  mfu: 0.29%
[rank0]:[titan] 2026-02-27 02:11:19,390 - root - INFO - step:  2  loss:  6.58548  grad_norm:  2.8759  memory: 15.94GiB(11.40%)  tps: 18,439  tflops: 12.46  mfu: 1.26%
[rank0]:[titan] 2026-02-27 02:11:19,592 - root - INFO - step:  3  loss:  4.87090  grad_norm:  2.6618  memory: 15.94GiB(11.40%)  tps: 81,233  tflops: 54.88  mfu: 5.55%
[rank0]:[titan] 2026-02-27 02:11:19,795 - root - INFO - step:  4  loss:  4.63837  grad_norm:  2.2442  memory: 15.94GiB(11.40%)  tps: 80,948  tflops: 54.68  mfu: 5.53%
[rank0]:[titan] 2026-02-27 02:11:20,000 - root - INFO - step:  5  loss:  4.33038  grad_norm:  2.2374  memory: 15.94GiB(11.40%)  tps: 80,084  tflops: 54.10  mfu: 5.47%

With fix:

=== seed=0 ===
[rank0]:[titan] 2026-02-27 02:06:57,972 - root - INFO - step:  1  loss:  8.13326  grad_norm:  2.0193  memory: 15.17GiB(10.85%)  tps: 4,057  tflops: 2.74  mfu: 0.28%
[rank0]:[titan] 2026-02-27 02:06:58,690 - root - INFO - step:  2  loss:  6.62090  grad_norm:  2.6836  memory: 15.20GiB(10.87%)  tps: 22,830  tflops: 15.42  mfu: 1.56%
[rank0]:[titan] 2026-02-27 02:06:58,888 - root - INFO - step:  3  loss:  4.99889  grad_norm:  3.1322  memory: 15.20GiB(10.87%)  tps: 82,794  tflops: 55.93  mfu: 5.66%
[rank0]:[titan] 2026-02-27 02:06:59,090 - root - INFO - step:  4  loss:  4.77399  grad_norm:  2.5997  memory: 15.20GiB(10.87%)  tps: 81,333  tflops: 54.94  mfu: 5.56%
[rank0]:[titan] 2026-02-27 02:06:59,294 - root - INFO - step:  5  loss:  4.45365  grad_norm:  2.6487  memory: 15.20GiB(10.87%)  tps: 80,341  tflops: 54.27  mfu: 5.49%
=== seed=42 ===
[rank0]:[titan] 2026-02-27 02:07:25,836 - root - INFO - step:  1  loss:  8.08693  grad_norm:  2.3529  memory: 15.17GiB(10.85%)  tps: 4,229  tflops: 2.86  mfu: 0.29%
[rank0]:[titan] 2026-02-27 02:07:26,558 - root - INFO - step:  2  loss:  6.39597  grad_norm:  2.9923  memory: 15.20GiB(10.87%)  tps: 22,702  tflops: 15.34  mfu: 1.55%
[rank0]:[titan] 2026-02-27 02:07:26,763 - root - INFO - step:  3  loss:  4.89488  grad_norm:  2.9787  memory: 15.20GiB(10.87%)  tps: 80,090  tflops: 54.10  mfu: 5.47%
[rank0]:[titan] 2026-02-27 02:07:26,973 - root - INFO - step:  4  loss:  4.59138  grad_norm:  2.4255  memory: 15.20GiB(10.87%)  tps: 78,404  tflops: 52.96  mfu: 5.36%
[rank0]:[titan] 2026-02-27 02:07:27,181 - root - INFO - step:  5  loss:  4.28487  grad_norm:  2.4017  memory: 15.20GiB(10.87%)  tps: 78,751  tflops: 53.20  mfu: 5.38%
=== seed=123 ===
[rank0]:[titan] 2026-02-27 02:07:53,939 - root - INFO - step:  1  loss:  7.98198  grad_norm:  2.1047  memory: 15.17GiB(10.85%)  tps: 4,140  tflops: 2.80  mfu: 0.28%
[rank0]:[titan] 2026-02-27 02:07:54,682 - root - INFO - step:  2  loss:  6.41391  grad_norm:  2.8769  memory: 15.20GiB(10.87%)  tps: 22,063  tflops: 14.90  mfu: 1.51%
[rank0]:[titan] 2026-02-27 02:07:54,885 - root - INFO - step:  3  loss:  4.95427  grad_norm:  3.0792  memory: 15.20GiB(10.87%)  tps: 81,077  tflops: 54.77  mfu: 5.54%
[rank0]:[titan] 2026-02-27 02:07:55,091 - root - INFO - step:  4  loss:  4.74782  grad_norm:  2.5689  memory: 15.20GiB(10.87%)  tps: 79,661  tflops: 53.81  mfu: 5.44%
[rank0]:[titan] 2026-02-27 02:07:55,297 - root - INFO - step:  5  loss:  4.40844  grad_norm:  2.5907  memory: 15.20GiB(10.87%)  tps: 79,480  tflops: 53.69  mfu: 5.43%
=== seed=999 ===
[rank0]:[titan] 2026-02-27 02:08:22,540 - root - INFO - step:  1  loss:  8.13762  grad_norm:  2.0673  memory: 15.17GiB(10.85%)  tps: 4,155  tflops: 2.81  mfu: 0.28%
[rank0]:[titan] 2026-02-27 02:08:23,305 - root - INFO - step:  2  loss:  6.58548  grad_norm:  2.8759  memory: 15.20GiB(10.87%)  tps: 21,446  tflops: 14.49  mfu: 1.46%
[rank0]:[titan] 2026-02-27 02:08:23,505 - root - INFO - step:  3  loss:  4.87090  grad_norm:  2.6618  memory: 15.20GiB(10.87%)  tps: 82,171  tflops: 55.51  mfu: 5.61%
[rank0]:[titan] 2026-02-27 02:08:23,709 - root - INFO - step:  4  loss:  4.63837  grad_norm:  2.2442  memory: 15.20GiB(10.87%)  tps: 80,596  tflops: 54.45  mfu: 5.51%
[rank0]:[titan] 2026-02-27 02:08:23,912 - root - INFO - step:  5  loss:  4.33038  grad_norm:  2.2374  memory: 15.20GiB(10.87%)  tps: 80,765  tflops: 54.56  mfu: 5.52%

@tianyu-l (Contributor) left a comment:

stamp

@tianyu-l (Contributor):

maybe we should also change gpt-oss bias init to 0? @rthekini-aws

@rthekini-aws (Contributor, Author):

> maybe we should also change gpt-oss bias init to 0? @rthekini-aws

@tianyu-l Are you referring to the fix at #2450? Or are you referring to the fact that q/k/v/o biases are initialized with trunc_normal rather than zeros?

@tianyu-l (Contributor):

@rthekini-aws I mean gpt-oss attention linear bias and gpt-oss moe bias (https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/gpt_oss/moe.py#L269)
