
[DTensor Bugfix] Explicitly specify grad_placements in to_local to ensure necessary all reduce takes place #2532

Open
acisseJZhong wants to merge 4 commits into main from fix_to_local

Conversation

@acisseJZhong
Contributor

@acisseJZhong acisseJZhong commented Mar 9, 2026

Fixes #2217, reported by @alpemreacar.

Problem:

x.to_local() was called on a multi-dimensional DTensor (e.g., on a (dp, tp) mesh with placements like (Shard(0), Replicate())). The bare to_local() strips all mesh dimensions and discards gradient-placement information for the non-DP (TP) dimensions. For an RMSNorm weight that is Replicate() on the TP mesh, the backward pass should produce Partial() gradients (which require an all-reduce), but that information was lost.

Fix:

  1. Extract the non-DP placements from x
  2. Compute the corresponding grad placements — Replicate() becomes Partial(), all others stay as-is
  3. Pass the full grad placements (DP dims + non-DP dims) to to_local()

This ensures that in the backward pass, gradients flowing back through to_local() are properly wrapped as a DTensor with Partial() on the TP mesh dimension, which will trigger the necessary all-reduce across TP ranks.
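The placement mapping in step 2 can be sketched as a small pure function. This is an illustrative, torch-free sketch: `Replicate`, `Shard`, and `Partial` here are minimal stand-ins for the real `torch.distributed.tensor` placement types, and `compute_grad_placements` is a hypothetical helper name, not necessarily what the PR calls it.

```python
# Sketch of the Replicate -> Partial grad-placement mapping described above.
# Replicate/Shard/Partial are stand-ins for torch.distributed.tensor placements.
from dataclasses import dataclass


@dataclass(frozen=True)
class Replicate:
    pass


@dataclass(frozen=True)
class Partial:
    pass


@dataclass(frozen=True)
class Shard:
    dim: int


def compute_grad_placements(placements):
    # A Replicate() forward placement means the backward gradient is a sum of
    # per-rank contributions, i.e. Partial(), which triggers an all-reduce
    # when the gradient DTensor is later redistributed. Shard stays as-is.
    return tuple(Partial() if isinstance(p, Replicate) else p for p in placements)


# Forward placements on a (dp, tp) mesh: (Shard(0), Replicate())
# -> grad placements: (Shard(0), Partial())
grads = compute_grad_placements((Shard(0), Replicate()))
```

The full tuple (DP dims plus the mapped non-DP dims) would then be passed as `to_local(grad_placements=...)`, roughly `x.to_local(grad_placements=dp_placements + compute_grad_placements(tp_placements))`.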

Authored with Claude.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 9, 2026

Labels

ciflow/8gpu, CLA Signed


Development

Successfully merging this pull request may close these issues.

Understanding RMSNorm Gradient Synchronization in Tensor-Parallel LLaMA
