
Understanding RMSNorm Gradient Synchronization in Tensor-Parallel LLaMA #2217

@alpemreacar

Description


I am testing tensor parallelism (TP) with the SimpleFSDP implementation of LLaMA-3. I am on commit 9f211ec, using Python 3.10 and a nightly PyTorch build, as noted in the README.

My understanding is that under TP, model weights are sharded across ranks so that each rank performs a fraction of the total computation. I am able to capture the execution graphs using TorchInductor by running:

TORCH_COMPILE_DEBUG=1 CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 -m torchtitan.train --job.config_file=tp.toml --model.name simple_fsdp.llama3 --job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config

I have attached my toml configuration file along with a reference implementation.

As expected, I observe both all_gather and reduce_scatter operations in the forward and backward passes, since each rank is computing only a partial result.
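To make my mental model concrete, here is a single-process sketch I wrote (my own illustration with made-up shapes, not torchtitan code) of why those reduction collectives appear: in a Megatron-style MLP split, the first weight is sharded by columns and the second by rows, so each rank produces only a partial sum, and summing the partials (what all_reduce / reduce_scatter performs) recovers the full result:

```python
# Single-process illustration of a Megatron-style TP split (assumed shapes,
# not torchtitan code). W1 is sharded column-wise and W2 row-wise, so each
# simulated "rank" produces a partial output; summing the partials (the job
# of all_reduce / reduce_scatter) reproduces the unsharded computation.
import torch

torch.manual_seed(0)
TP = 2
X = torch.randn(4, 8)    # (tokens, hidden)
W1 = torch.randn(8, 16)  # up-projection
W2 = torch.randn(16, 8)  # down-projection

# Rank i holds the i-th column shard of W1 and the i-th row shard of W2.
partials = [
    (X @ w1_shard) @ w2_shard
    for w1_shard, w2_shard in zip(W1.chunk(TP, dim=1), W2.chunk(TP, dim=0))
]

# Summing per-rank partials matches the full (unsharded) matmul chain.
assert torch.allclose(sum(partials), X @ W1 @ W2, atol=1e-4)
```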

However, I am confused about how the RMSNorm parameters are updated. In the LLaMA implementation, each transformer block processes tensors of shape (B, SeqLen / TP, HiddenDim). As a result, each rank only observes gradients corresponding to B × SeqLen / TP tokens for the RMSNorm parameters. These gradients should therefore be reduced across ranks.
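The following single-process sketch (again my own toy reproduction, not the torchtitan implementation) shows why I expect a reduction: the RMSNorm weight gradient is a sum over tokens, so with sequence-sharded activations the full gradient equals the sum of per-rank shard gradients, which would normally require an all_reduce:

```python
# Toy check (not torchtitan code): when activations are sharded along the
# sequence dimension into (B, S // TP, H) pieces, the RMSNorm weight gradient
# on the full sequence equals the SUM of the per-shard gradients, i.e. an
# all_reduce across ranks would recover the full gradient.
import torch

def rmsnorm(x, w, eps=1e-6):
    # y = x / rms(x) * w, with rms taken over the hidden dimension
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w

torch.manual_seed(0)
B, S, H, TP = 2, 8, 4, 2
x = torch.randn(B, S, H)

# Reference: gradient from the full, unsharded sequence.
w_full = torch.ones(H, requires_grad=True)
rmsnorm(x, w_full).sum().backward()

# Simulated per-rank gradients on sequence shards of shape (B, S // TP, H).
shard_grads = []
for shard in x.chunk(TP, dim=1):
    w = torch.ones(H, requires_grad=True)
    rmsnorm(shard, w).sum().backward()
    shard_grads.append(w.grad)

# Summing shard gradients (the all_reduce) matches the full gradient.
assert torch.allclose(sum(shard_grads), w_full.grad, atol=1e-5)
```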

Yet when inspecting the generated graphs, I do not see any explicit all_reduce for the RMSNorm parameter gradients.

Could someone clarify where and how this gradient reduction is performed during execution? Is it fused, implicit, or handled outside of the graph I am inspecting?

Any insight would be appreciated.

tptoml.txt

Status: In Progress