Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions scripts/estimate/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@ def estimate_memory(job_config: JobConfig):
# Get the world size
world_size = int(os.environ["WORLD_SIZE"])

if job_config.compile.enable or job_config.parallelism.enable_compiled_autograd:
if job_config.compile.enable:
logger.info("Compile mode is not supported yet. Switching to eager mode.")
job_config.compile.enable = False
job_config.parallelism.enable_compiled_autograd = False

# init fake pg
store = FakeStore()
Expand Down Expand Up @@ -80,10 +79,7 @@ def estimate_memory(job_config: JobConfig):
loss_parallel_enabled = (
parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
)
train_context = dist_utils.get_train_context(
loss_parallel_enabled,
job_config.parallelism.enable_compiled_autograd,
)
train_context = dist_utils.get_train_context(loss_parallel_enabled)

# build model (using meta init)
model_args = train_spec.model_args[job_config.model.flavor]
Expand Down
3 changes: 0 additions & 3 deletions torchtitan/config/job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,6 @@ class Parallelism:
1 means disabled.
"""

enable_compiled_autograd: bool = False
"""Enable CompiledAutograd to compile the backward."""

data_parallel_shard_degree: int = -1
"""
The `data_parallel_shard_degree` argument specifies the degree of data
Expand Down
9 changes: 1 addition & 8 deletions torchtitan/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,20 +193,13 @@ def create_context_parallel_ctx(
)


def get_train_context(
enable_loss_parallel: bool, enable_compiled_autograd: bool
) -> Generator[None, None, None]:
def get_train_context(enable_loss_parallel: bool) -> Generator[None, None, None]:
@contextlib.contextmanager
def context(cp_context: Generator[None, None, None] | None = None):
with contextlib.ExitStack() as stack:
if enable_loss_parallel:
stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())

if enable_compiled_autograd:
stack.enter_context(
torch._dynamo.utils.maybe_enable_compiled_autograd(True)
)

if cp_context:
stack.enter_context(cp_context)

Expand Down
5 changes: 1 addition & 4 deletions torchtitan/experiments/forge/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,7 @@ def __init__(self, job_config: ForgeJobConfig):
loss_parallel_enabled = (
parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
)
self.train_context = dist_utils.get_train_context(
loss_parallel_enabled,
parallelism_config.enable_compiled_autograd,
)
self.train_context = dist_utils.get_train_context(loss_parallel_enabled)
self.maybe_enable_amp = dist_utils.maybe_enable_amp(
parallel_dims,
job_config.training.mixed_precision_param,
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/experiments/gpt_oss/infra/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
torch._higher_order_ops.flex_attention,
}


# Adapted from llama4/infra/parallelize.py
def parallelize_gptoss(
model: nn.Module,
Expand Down Expand Up @@ -168,7 +169,6 @@ def parallelize_gptoss(
model,
dp_mesh,
enable_compile=model_compile_enabled,
enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
)

return model
Expand Down
1 change: 0 additions & 1 deletion torchtitan/experiments/vlm/infra/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def parallelize_vlm(
model,
world_mesh,
enable_compile=job_config.compile.enable,
enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
)

return model
Expand Down
1 change: 0 additions & 1 deletion torchtitan/models/deepseek_v3/infra/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ def parallelize_deepseekv3(
model,
dp_mesh,
enable_compile=model_compile_enabled,
enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
)

return model
Expand Down
9 changes: 1 addition & 8 deletions torchtitan/models/llama3/infra/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ def parallelize_llama(
model,
world_mesh,
enable_compile=model_compile_enabled,
enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
)

return model
Expand Down Expand Up @@ -322,15 +321,9 @@ def apply_ddp(
model: nn.Module,
dp_mesh: DeviceMesh,
enable_compile: bool,
enable_compiled_autograd: bool,
):
if enable_compile:
if enable_compiled_autograd:
torch._dynamo.config.optimize_ddp = (
"python_reducer_without_compiled_forward"
)
else:
torch._dynamo.config.optimize_ddp = "ddp_optimizer"
torch._dynamo.config.optimize_ddp = "ddp_optimizer"

replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

Expand Down
1 change: 0 additions & 1 deletion torchtitan/models/llama4/infra/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ def parallelize_llama(
model,
dp_mesh,
enable_compile=model_compile_enabled,
enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
)

return model
Expand Down
1 change: 0 additions & 1 deletion torchtitan/models/qwen3/infra/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ def parallelize_qwen3(
model,
world_mesh,
enable_compile=model_compile_enabled,
enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
)

# Enable weight tying after applying parallelisms
Expand Down
5 changes: 1 addition & 4 deletions torchtitan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,7 @@ def __init__(self, job_config: JobConfig):
loss_parallel_enabled = (
parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
)
self.train_context = dist_utils.get_train_context(
loss_parallel_enabled,
parallelism_config.enable_compiled_autograd,
)
self.train_context = dist_utils.get_train_context(loss_parallel_enabled)
self.maybe_enable_amp = dist_utils.maybe_enable_amp(
parallel_dims,
job_config.training.mixed_precision_param,
Expand Down
Loading