Description
I am testing tensor parallelism (TP) using a simple FSDP-based implementation of LLaMA-3. I am on commit 9f211ec, using Python 3.10 and a nightly version of PyTorch, as noted in the README.
My understanding is that under TP, model weights are sharded across ranks so that each rank performs a fraction of the total computation. I am able to capture the execution graphs using TorchInductor by running:
TORCH_COMPILE_DEBUG=1 CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 -m torchtitan.train --job.config_file=tp.toml --model.name simple_fsdp.llama3 --job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config
I have attached my toml configuration file along with a reference implementation.
As expected, I observe both all_gather and reduce_scatter operations in the forward and backward passes, since each rank is computing only a partial result.
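To make the expectation concrete, here is a minimal numpy sketch (not torchtitan code; names are illustrative) of why those collectives appear: the full weight gradient of a linear layer is the sum of per-rank partial gradients over the locally visible tokens, and a reduce_scatter both performs that sum and leaves each rank with only its shard.

```python
import numpy as np

# Illustrative sketch: with tokens sharded across TP ranks, the full
# weight gradient x^T @ g equals the sum of per-rank partial gradients.
# reduce_scatter sums the partials and hands each rank its own shard.
rng = np.random.default_rng(0)
tp, tokens, d_in, d_out = 2, 8, 4, 6
x = rng.standard_normal((tokens, d_in))   # activations, all tokens
g = rng.standard_normal((tokens, d_out))  # upstream gradients

full_grad = x.T @ g  # what a single-device run would compute

# Each "rank" sees only tokens / tp of the sequence.
partials = [xs.T @ gs for xs, gs in zip(np.split(x, tp), np.split(g, tp))]

# Emulate reduce_scatter: sum across ranks, then shard the result by rows.
summed = sum(partials)
shards = np.split(summed, tp, axis=0)

assert np.allclose(np.concatenate(shards, axis=0), full_grad)
```

This matches the graphs: sharded matmul weights get their cross-rank sum via reduce_scatter in the backward pass.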
What I am unsure about, though, is how the RMSNorm parameters are updated. In the LLaMA implementation, each transformer block processes tensors of shape (B, SeqLen / TP, HiddenDim), so each rank only observes gradients for the RMSNorm parameters from B × SeqLen / TP tokens. These partial gradients should therefore be reduced across ranks.
However, when inspecting the generated graphs, I do not see any explicit all_reduce corresponding to the RMSNorm parameter gradients.
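The reduction I would expect can be checked numerically with a small sketch (again illustrative, not torchtitan code): since the RMSNorm weight gradient is a per-token sum, splitting the sequence across ranks produces partial gradients whose sum equals the full gradient, which is exactly what an all_reduce would compute.

```python
import numpy as np

def rmsnorm_weight_grad(x, grad_out, eps=1e-6):
    # For y = w * x / rms(x), the weight gradient is
    #   dL/dw = sum over tokens of grad_out * x / rms(x)
    # x, grad_out: (tokens, hidden) -> returns (hidden,)
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    return (grad_out * x / rms).sum(axis=0)

rng = np.random.default_rng(0)
tp, tokens, hidden = 2, 8, 4
x = rng.standard_normal((tokens, hidden))
g = rng.standard_normal((tokens, hidden))

full = rmsnorm_weight_grad(x, g)

# Each "rank" sees a SeqLen / TP slice of the tokens; summing the
# per-rank partial gradients recovers the full gradient -- the job
# of the all_reduce that I cannot find in the captured graphs.
partial_sum = sum(rmsnorm_weight_grad(xs, gs)
                  for xs, gs in zip(np.split(x, tp), np.split(g, tp)))

assert np.allclose(full, partial_sum)
```

So the math says a cross-rank sum is needed for these replicated parameters, yet no corresponding collective shows up in the graph.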
Could someone clarify where and how this gradient reduction is performed during execution? Is it fused, implicit, or handled outside of the graph I am inspecting?
Any insight would be appreciated.