added custom trunc_normal#2342

Merged
tianyu-l merged 4 commits into pytorch:main from francesco-bertolotti:f14-trunc-normal
Feb 9, 2026

Conversation

@francesco-bertolotti
Contributor

This PR addresses #2269

Briefly, there is a numerical instability in torch.nn.init.trunc_normal_ that causes an abnormally large number of samples to land exactly on the left bound in the weights.

I consistently replaced every usage of torch.nn.init.trunc_normal_ with a custom implementation of trunc_normal_ located in torchtitan/models/utils.py.
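For intuition, the idea behind a numerically stable truncated-normal sampler can be sketched with the inverse-CDF method. This is an illustrative sketch using only the standard library, not the actual implementation in torchtitan/models/utils.py: uniform samples are drawn strictly inside (CDF(a), CDF(b)) and mapped back through the inverse CDF, so no sample collapses onto the left bound the way an unstable erfinv-based routine can.

```python
import random
from statistics import NormalDist


def trunc_normal(n, mean=0.0, std=1.0, a=-2.0, b=2.0, seed=0):
    """Sample n values from a normal distribution truncated to [a, b].

    Illustrative sketch (NOT the torchtitan implementation): map uniform
    samples from (CDF(a), CDF(b)) through the inverse CDF, so every
    sample lands strictly inside the interval instead of piling up on
    the left bound.
    """
    rng = random.Random(seed)
    dist = NormalDist(mu=mean, sigma=std)
    lo, hi = dist.cdf(a), dist.cdf(b)
    return [dist.inv_cdf(rng.uniform(lo, hi)) for _ in range(n)]
```

A real PyTorch version would do the same transform on a tensor in-place; the point here is only the bound-respecting sampling scheme.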


I have also made two other fixes that deserve some attention; they do not have much to do with trunc_normal_, but they felt right.

  1. Qwen3 was not initializing the output weights when enable_weight_tying=True. This meant the embedding initialization was used for the output weights, which caused the loss to skyrocket past 500.

  2. GPTOSS was calling init_weights in its __init__ method, which causes some errors.

  3. I did not swap the initialization in the experiments folder, but I can easily add a commit changing those too.


Here I have some debug runs with associated losses:

(debug-run losses: deepseek, llama3, llama4, qwen3_moe, gptoss)

Since the debug models do not show the difference the new `trunc_normal_` makes, I also have a longer run (1000 steps) on the [synth](https://huggingface.co/datasets/PleIAs/SYNTH) dataset using Qwen 4B. It appears that this instability has a compounding effect with respect to model size.

(qwen3-4b run)

@meta-cla meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) Feb 8, 2026
@tianyu-l
Contributor

I did not swap the initialization in the experiments folder, but I can easily add a commit changing those too.

Any reason we don't? If not, I think we should.

GPTOSS was calling init_weights in its __init__ method, which causes some errors.

Why does it cause errors? We can remove those if so.

b=cutoff_factor * final_out_std,
)

# If weight tying is enabled, we don't need to initialize the output layer
Contributor

Qwen3 was not initializing the output weights when enable_weight_tying=True. This meant the embedding initialization was used for the output weights, which caused the loss to skyrocket past 500.

To clarify, do you mean we should override embedding init with this output weight init? Why one direction of override is better than the other direction?

Contributor Author

The problem is that the output initialization should have a small std; otherwise we get high logits and, consequently, a high loss. If the weights are tied, we should prioritize the output layer initialization. At the moment, running

NGPU=1 CONFIG_FILE="./torchtitan/models/qwen3/train_configs/qwen3_0.6b.toml" ./run_train.sh --training.steps 100 --training.seq_len 256 --compile.no-enable --training.dtype bfloat16 --metrics.enable-tensorboard --job.dump-folder "./outputs/qwen3"

yield

[rank0]:[titan] 2026-02-09 09:03:01,409 - root - INFO - step:  1  loss: 127.88081  grad_norm: 42.2500  memory:  4.48GiB(28.75%)  tps: 267  tflops: 0.75  mfu: 0.24%
[rank0]:[titan] 2026-02-09 09:03:01,696 - root - INFO - step:  2  loss: 124.81187  grad_norm: 66.5000  memory:  6.21GiB(39.92%)  tps: 3,576  tflops: 10.08  mfu: 3.23%
[rank0]:[titan] 2026-02-09 09:03:01,905 - root - INFO - step:  3  loss: 119.61317  grad_norm: 43.0000  memory:  6.21GiB(39.92%)  tps: 4,909  tflops: 13.84  mfu: 4.44%
[rank0]:[titan] 2026-02-09 09:03:02,110 - root - INFO - step:  4  loss: 119.98587  grad_norm: 238.0000  memory:  6.21GiB(39.92%)  tps: 5,006  tflops: 14.11  mfu: 4.52%
[rank0]:[titan] 2026-02-09 09:03:02,314 - root - INFO - step:  5  loss: 115.53140  grad_norm: 29.8750  memory:  6.21GiB(39.92%)  tps: 5,035  tflops: 14.19  mfu: 4.55%
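The override-order point can be sketched with a toy example (assumed sizes, not the actual Qwen3 module): with tied weights, the embedding and output projection share one tensor, so whichever init runs last wins. Running the small-std output init after the embedding init is what keeps the logits, and hence the loss, small.

```python
import torch
import torch.nn as nn

# Toy sketch (hypothetical dimensions, not Qwen3 itself).
emb = nn.Embedding(100, 16)
output = nn.Linear(16, 100, bias=False)
output.weight = emb.weight  # weight tying: one shared tensor

with torch.no_grad():
    emb.weight.normal_(mean=0.0, std=1.0)      # embedding-scale init
    output.weight.normal_(mean=0.0, std=0.02)  # output init runs last and wins

# Because storage is shared, the small output std is what survives,
# so the tied matrix produces small logits at the output.
```

If the order were reversed, the large embedding-scale std would survive instead, reproducing the high-loss behavior shown in the logs above.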

Contributor

I see. I guess this is the reason for #1879? @wwwjn
@francesco-bertolotti could you add a comment on why we are overriding?

@francesco-bertolotti
Contributor Author

I didn’t swap the initialization in the experiments folder yet, but I can add a commit for that as well.

any reason we don't? If not, I think we should.

No objection on my end — I just wanted to confirm before doing it.


why it cause errors? we can remove those if so

In gpt_oss, init_weights is called inside __init__, which causes issues when using the custom trunc_normal_. Specifically, it triggers
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

This only happens for gpt_oss. Since no other model calls init_weights in its own __init__, I assumed weight initialization is expected to happen when the model is materialized in train.py.

If keeping init_weights inside __init__ is required, an alternative would be to add @torch.no_grad() to the custom trunc_normal_.
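A minimal reproduction of that error, together with the torch.no_grad() workaround mentioned above (toy tensor, not the actual gpt_oss module):

```python
import torch

# A leaf tensor with requires_grad=True rejects in-place ops under autograd.
w = torch.empty(4, requires_grad=True)
try:
    w.normal_(mean=0.0, std=0.02)  # raises "a leaf Variable that requires
                                   # grad is being used in an in-place operation"
except RuntimeError as e:
    print(type(e).__name__)

# Wrapping the in-place init in no_grad() (e.g. via @torch.no_grad() on a
# custom trunc_normal_) disables autograd tracking, so the same op succeeds.
with torch.no_grad():
    w.normal_(mean=0.0, std=0.02)
```

Deferring init_weights to after materialization (as train.py does) avoids the issue entirely, which is why removing the call from __init__ is the cleaner fix.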

@tianyu-l
Contributor

tianyu-l commented Feb 9, 2026

@francesco-bertolotti

If keeping init_weights inside __init__ is required, an alternative would be to add @torch.no_grad() to the custom trunc_normal_.

No it's not required / expected. Let's just remove it.

@wwwjn wwwjn self-assigned this Feb 9, 2026
@francesco-bertolotti
Contributor Author

francesco-bertolotti commented Feb 9, 2026

In the last commits I have:

  • tweaked the documentation for trunc_normal_,
  • introduced a comment explaining the re-initialization of the output embeddings in Qwen3,
  • ran pre-commit,
  • replaced the init function in the experiments folder with the custom one.

@tianyu-l tianyu-l merged commit e38d7ab into pytorch:main Feb 9, 2026
14 of 15 checks passed
@tianyu-l
Contributor

tianyu-l commented Feb 9, 2026

@francesco-bertolotti could you help remove this line as well?
https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/simple_fsdp/deepseek_v3/model.py#L15

tianyu-l pushed a commit that referenced this pull request Feb 10, 2026
removing weight initialization from the model's __init__ as per request from @tianyu-l in [comment](#2342 (comment))