Merged
21 commits
- 740beac  [torchtitan][replicate] experimenting new replicate integration with … (anshul-si, Sep 15, 2025)
- 82ccb85  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Sep 16, 2025)
- b946784  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Sep 16, 2025)
- 23421c2  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Sep 23, 2025)
- 25632be  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Sep 23, 2025)
- c383089  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Sep 29, 2025)
- 0951da7  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Sep 30, 2025)
- ba2a3fc  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Nov 4, 2025)
- 6364c23  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Nov 5, 2025)
- b7bd2f0  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Nov 5, 2025)
- 9ea3fee  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Nov 5, 2025)
- f61c56a  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Nov 5, 2025)
- 92e54e8  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Nov 5, 2025)
- 9effc88  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Nov 6, 2025)
- de0a0bc  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Feb 6, 2026)
- f6816e2  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Feb 9, 2026)
- 03debc5  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Feb 11, 2026)
- 096ce87  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Feb 11, 2026)
- dfebee0  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Feb 11, 2026)
- 688c5a5  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Feb 11, 2026)
- dd30603  Update on "[torchtitan][replicate] experimenting new replicate integr… (anshul-si, Feb 12, 2026)
4 changes: 3 additions & 1 deletion torchtitan/distributed/parallel_dims.py
@@ -151,7 +151,9 @@ def _build_mesh_without_ep(self) -> DeviceMesh:
             [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
             ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
         ):
-            if d > 1:
+            # Include dp_shard dimension even if it equals 1 when replicate > 1
+            # to make device_mesh compatible with replicate function
+            if d > 1 or (name == "dp_shard" and self.dp_replicate > 1):
                 dims.append(d)
                 names.append(name)
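The motivation for keeping a size-1 dp_shard dim: with pure replication (dp_replicate > 1, dp_shard == 1) the mesh used to collapse to 1-D, so the ("dp_replicate", "dp_shard") slices the callers below rely on would fail. A minimal sketch of the resulting mesh, assuming a hypothetical 4-rank job launched with torchrun:

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# dp_replicate=4, dp_shard=1; before this change the size-1 dim was dropped
# and the world mesh came out 1-D, making the 2-D slice below impossible.
world_mesh = init_device_mesh(
    "cuda", (4, 1), mesh_dim_names=("dp_replicate", "dp_shard")
)
dp_mesh = world_mesh[("dp_replicate", "dp_shard")]  # what apply_ddp now receives
assert dp_mesh.shape == (4, 1)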
9 changes: 5 additions & 4 deletions torchtitan/experiments/llama4/infra/parallelize.py
@@ -169,14 +169,15 @@ def parallelize_llama(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        dp_mesh = world_mesh
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
         apply_ddp(
             model,
-            dp_mesh,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
             enable_compile=model_compile_enabled,
             enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            cpu_offload=job_config.training.enable_cpu_offload,
         )

     return model
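The same caller-side rewrite recurs in qwen3, vlm, deepseek_v3, and llama3 below: drop the 1-D mesh guard, pass the 2-D ("dp_replicate", "dp_shard") sub-mesh, and forward the mixed-precision dtypes and CPU offload flag. A sketch of the resulting call shape; model and world_mesh stand for the objects already in scope inside each parallelize function, and the dtypes are placeholders for the TORCH_DTYPE_MAP lookups:

import torch

# apply_ddp as modified in this PR (torchtitan/models/llama3/infra/parallelize.py).
from torchtitan.models.llama3.infra.parallelize import apply_ddp

dp_mesh_dim_names = ("dp_replicate", "dp_shard")
apply_ddp(
    model,                                 # nn.Module built earlier in parallelize_*
    world_mesh[tuple(dp_mesh_dim_names)],  # 2-D sub-mesh; the old code required ndim == 1
    param_dtype=torch.bfloat16,            # placeholder for TORCH_DTYPE_MAP[...param]
    reduce_dtype=torch.float32,            # placeholder for TORCH_DTYPE_MAP[...reduce]
    enable_compile=False,
    enable_compiled_autograd=False,
    cpu_offload=False,
)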
8 changes: 5 additions & 3 deletions torchtitan/experiments/qwen3/infra/parallelize.py
@@ -164,13 +164,15 @@ def parallelize_qwen3(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
         apply_ddp(
             model,
-            world_mesh,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
             enable_compile=model_compile_enabled,
             enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            cpu_offload=job_config.training.enable_cpu_offload,
         )

     # Enable weight tying after applying parallelisms
10 changes: 6 additions & 4 deletions torchtitan/experiments/vlm/infra/parallelize.py
@@ -101,13 +101,15 @@ def parallelize_vlm(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
         apply_ddp(
             model,
-            world_mesh,
-            enable_compile=job_config.compile.enable,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
+            enable_compile=model_compile_enabled,
             enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            cpu_offload=job_config.training.enable_cpu_offload,
         )

     return model
9 changes: 5 additions & 4 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -162,14 +162,15 @@ def parallelize_deepseekv3(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        dp_mesh = world_mesh
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
         apply_ddp(
             model,
-            dp_mesh,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
             enable_compile=model_compile_enabled,
             enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            cpu_offload=job_config.training.enable_cpu_offload,
         )

     return model
42 changes: 30 additions & 12 deletions torchtitan/models/llama3/infra/parallelize.py
@@ -9,7 +9,7 @@

 import torch
 import torch.nn as nn
-from torch.distributed._composable.replicate import replicate
+from torch.distributed._composable.replicate_with_fsdp import replicate

 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
@@ -135,13 +135,15 @@ def parallelize_llama(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
         apply_ddp(
             model,
-            world_mesh,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
             enable_compile=model_compile_enabled,
             enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            cpu_offload=job_config.training.enable_cpu_offload,
         )

     return model
@@ -317,17 +319,33 @@ def apply_fsdp(
 def apply_ddp(
     model: nn.Module,
     dp_mesh: DeviceMesh,
+    param_dtype: torch.dtype,
+    reduce_dtype: torch.dtype,
     enable_compile: bool,
     enable_compiled_autograd: bool,
+    cpu_offload: bool = False,
 ):
-    if enable_compile:
-        if enable_compiled_autograd:
-            torch._dynamo.config.optimize_ddp = (
-                "python_reducer_without_compiled_forward"
-            )
-        else:
-            torch._dynamo.config.optimize_ddp = "ddp_optimizer"
+    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
+    replicate_config = {"device_mesh": dp_mesh, "mp_policy": mp_policy}
+    if cpu_offload:
+        replicate_config["offload_policy"] = CPUOffloadPolicy()

-    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    if model.tok_embeddings is not None:
+        replicate(
+            model.tok_embeddings,
+            **replicate_config,
+        )
+    for layer_id, transformer_block in model.layers.items():
+        replicate(
+            transformer_block,
+            **replicate_config,
+        )
+
+    if model.norm is not None and model.output is not None:
+        replicate(
+            [model.norm, model.output],
+            **replicate_config,
+        )
+    replicate(model, **replicate_config)

     logger.info("Applied DDP to the model")
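For orientation, a self-contained sketch of the pattern apply_ddp now implements: replicate each major submodule first, then the root, mirroring how apply_fsdp shards per block. The toy model, mesh shape, and dtypes here are illustrative only, and replicate is the experimental, private torch.distributed._composable.replicate_with_fsdp API used in this diff, which may change:

import torch
import torch.nn as nn
from torch.distributed._composable.replicate_with_fsdp import replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy

# Toy stand-in for a transformer: embedding, blocks, then norm and output head.
model = nn.Sequential(
    nn.Embedding(1000, 64),
    nn.TransformerEncoderLayer(64, 4, batch_first=True),
    nn.TransformerEncoderLayer(64, 4, batch_first=True),
    nn.LayerNorm(64),
    nn.Linear(64, 1000),
)

# Assumes a 4-rank job launched with torchrun; dp_shard stays 1 for pure DDP.
mesh = init_device_mesh("cuda", (4, 1), mesh_dim_names=("dp_replicate", "dp_shard"))
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
replicate_config = {"device_mesh": mesh, "mp_policy": mp_policy}

# Children first so each gets its own replication group; root last to pick up
# any parameters not covered by a child, the same order apply_ddp uses above.
for child in model:
    replicate(child, **replicate_config)
replicate(model, **replicate_config)

The list form replicate([model.norm, model.output], ...) in the diff treats the final norm and the output projection as a single unit, the same way fully_shard accepts a list of modules to group them into one communication bucket.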