From 07c72008afd9d41ff41503a806de6074cba547f2 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Mon, 27 Oct 2025 13:29:24 -0700
Subject: [PATCH 1/2] Update (base update)

[ghstack-poisoned]

From a95c857a467a5dbf14ae6def38269031b14e0771 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Mon, 27 Oct 2025 13:29:24 -0700
Subject: [PATCH 2/2] Update

[ghstack-poisoned]
---
 scripts/estimate/estimation.py                      | 8 ++------
 torchtitan/config/job_config.py                     | 3 ---
 torchtitan/distributed/utils.py                     | 9 +--------
 torchtitan/experiments/forge/engine.py              | 5 +----
 torchtitan/experiments/gpt_oss/infra/parallelize.py | 2 +-
 torchtitan/experiments/vlm/infra/parallelize.py     | 1 -
 torchtitan/models/deepseek_v3/infra/parallelize.py  | 1 -
 torchtitan/models/llama3/infra/parallelize.py       | 9 +--------
 torchtitan/models/llama4/infra/parallelize.py       | 1 -
 torchtitan/models/qwen3/infra/parallelize.py        | 1 -
 torchtitan/train.py                                 | 5 +----
 11 files changed, 7 insertions(+), 38 deletions(-)

diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py
index b1f45c4051..e0a752d545 100644
--- a/scripts/estimate/estimation.py
+++ b/scripts/estimate/estimation.py
@@ -33,10 +33,9 @@ def estimate_memory(job_config: JobConfig):
     # Get the world size
     world_size = int(os.environ["WORLD_SIZE"])
 
-    if job_config.compile.enable or job_config.parallelism.enable_compiled_autograd:
+    if job_config.compile.enable:
         logger.info("Compile mode is not supported yet. Switching to eager mode.")
         job_config.compile.enable = False
-        job_config.parallelism.enable_compiled_autograd = False
 
     # init fake pg
     store = FakeStore()
@@ -80,10 +79,7 @@ def estimate_memory(job_config: JobConfig):
     loss_parallel_enabled = (
         parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
     )
-    train_context = dist_utils.get_train_context(
-        loss_parallel_enabled,
-        job_config.parallelism.enable_compiled_autograd,
-    )
+    train_context = dist_utils.get_train_context(loss_parallel_enabled)
 
     # build model (using meta init)
     model_args = train_spec.model_args[job_config.model.flavor]
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index f1800d0c51..7fe6802374 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -286,9 +286,6 @@ class Parallelism:
     1 means disabled.
     """
 
-    enable_compiled_autograd: bool = False
-    """Enable CompiledAutograd to compile the backward."""
-
     data_parallel_shard_degree: int = -1
     """
     The `data_parallel_shard_degree` argument specifies the degree of data
diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index 67eb41280f..93a96a4439 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -193,20 +193,13 @@ def create_context_parallel_ctx(
     )
 
 
-def get_train_context(
-    enable_loss_parallel: bool, enable_compiled_autograd: bool
-) -> Generator[None, None, None]:
+def get_train_context(enable_loss_parallel: bool) -> Generator[None, None, None]:
     @contextlib.contextmanager
     def context(cp_context: Generator[None, None, None] | None = None):
         with contextlib.ExitStack() as stack:
             if enable_loss_parallel:
                 stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
 
-            if enable_compiled_autograd:
-                stack.enter_context(
-                    torch._dynamo.utils.maybe_enable_compiled_autograd(True)
-                )
-
             if cp_context:
                 stack.enter_context(cp_context)
 
diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py
index f8b1412959..d832c39696 100644
--- a/torchtitan/experiments/forge/engine.py
+++ b/torchtitan/experiments/forge/engine.py
@@ -233,10 +233,7 @@ def __init__(self, job_config: ForgeJobConfig):
         loss_parallel_enabled = (
             parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
         )
-        self.train_context = dist_utils.get_train_context(
-            loss_parallel_enabled,
-            parallelism_config.enable_compiled_autograd,
-        )
+        self.train_context = dist_utils.get_train_context(loss_parallel_enabled)
         self.maybe_enable_amp = dist_utils.maybe_enable_amp(
             parallel_dims,
             job_config.training.mixed_precision_param,
diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py
index 9d538e13a1..7bd00c3525 100644
--- a/torchtitan/experiments/gpt_oss/infra/parallelize.py
+++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py
@@ -45,6 +45,7 @@
     torch._higher_order_ops.flex_attention,
 }
 
+
 # Adapted from llama4/infra/parallelize.py
 def parallelize_gptoss(
     model: nn.Module,
@@ -168,7 +169,6 @@ def parallelize_gptoss(
             model,
             dp_mesh,
             enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )
 
     return model
diff --git a/torchtitan/experiments/vlm/infra/parallelize.py b/torchtitan/experiments/vlm/infra/parallelize.py
index a8095c7621..6a97e4ece1 100644
--- a/torchtitan/experiments/vlm/infra/parallelize.py
+++ b/torchtitan/experiments/vlm/infra/parallelize.py
@@ -107,7 +107,6 @@ def parallelize_vlm(
             model,
             world_mesh,
             enable_compile=job_config.compile.enable,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )
 
     return model
diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py
index 8d13a3f31f..0793820ffd 100644
--- a/torchtitan/models/deepseek_v3/infra/parallelize.py
+++ b/torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -171,7 +171,6 @@ def parallelize_deepseekv3(
             model,
             dp_mesh,
             enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )
 
     return model
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index 4944af569e..86ac3a6dfe 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -143,7 +143,6 @@ def parallelize_llama(
             model,
             world_mesh,
             enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )
 
     return model
@@ -322,15 +321,9 @@ def apply_ddp(
     model: nn.Module,
     dp_mesh: DeviceMesh,
     enable_compile: bool,
-    enable_compiled_autograd: bool,
 ):
     if enable_compile:
-        if enable_compiled_autograd:
-            torch._dynamo.config.optimize_ddp = (
-                "python_reducer_without_compiled_forward"
-            )
-        else:
-            torch._dynamo.config.optimize_ddp = "ddp_optimizer"
+        torch._dynamo.config.optimize_ddp = "ddp_optimizer"
 
     replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
 
diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index 1f579ccd04..76a554d2f0 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -178,7 +178,6 @@ def parallelize_llama(
            model,
            dp_mesh,
            enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )
 
     return model
diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py
index 5fa8549e9f..6b8dc3d5a6 100644
--- a/torchtitan/models/qwen3/infra/parallelize.py
+++ b/torchtitan/models/qwen3/infra/parallelize.py
@@ -170,7 +170,6 @@ def parallelize_qwen3(
             model,
             world_mesh,
             enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )
 
     # Enable weight tying after applying parallelisms
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 9d118854e1..2efd7931ed 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -320,10 +320,7 @@ def __init__(self, job_config: JobConfig):
         loss_parallel_enabled = (
             parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
         )
-        self.train_context = dist_utils.get_train_context(
-            loss_parallel_enabled,
-            parallelism_config.enable_compiled_autograd,
-        )
+        self.train_context = dist_utils.get_train_context(loss_parallel_enabled)
         self.maybe_enable_amp = dist_utils.maybe_enable_amp(
             parallel_dims,
             job_config.training.mixed_precision_param,
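
The hunks above reduce dist_utils.get_train_context to a single flag. Below is a minimal, self-contained sketch (not part of the patch) of the simplified helper plus a toy caller; the Linear model, enable_loss_parallel=False, and the trailing yield/return context (which fall outside the hunk's context lines) are illustrative assumptions.

# Sketch of the simplified helper after this change; mirrors the new body in
# torchtitan/distributed/utils.py. The trailing yield/return are implied by the
# @contextlib.contextmanager decorator but sit outside the hunk shown above.
import contextlib
from collections.abc import Generator

import torch


def get_train_context(enable_loss_parallel: bool) -> Generator[None, None, None]:
    @contextlib.contextmanager
    def context(cp_context: Generator[None, None, None] | None = None):
        with contextlib.ExitStack() as stack:
            if enable_loss_parallel:
                # Only touched when TP loss parallelism is on; left off here so
                # the sketch runs on a single process.
                stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
            if cp_context:
                stack.enter_context(cp_context)
            yield

    return context


# Toy caller: with loss parallel off, the context is effectively a no-op wrapper.
train_context = get_train_context(enable_loss_parallel=False)
model = torch.nn.Linear(8, 8)
with train_context():  # a context-parallel context could be passed here instead
    model(torch.randn(4, 8)).sum().backward()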
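
Separately, removing the Parallelism.enable_compiled_autograd knob means torchtitan no longer enters compiled autograd on the user's behalf. If a run still wants it, the same context manager the deleted branch used can be applied around the backward by hand; the snippet below is an assumption about standalone use (toy model, single boolean argument as in the removed call), not something this patch documents.

# Illustrative only: directly reuse the context manager the deleted branch in
# get_train_context() entered. Assumes torch._dynamo.utils.maybe_enable_compiled_autograd
# still exists in your torch build and accepts a single boolean, as in the removed
# maybe_enable_compiled_autograd(True) call.
import torch
from torch._dynamo.utils import maybe_enable_compiled_autograd

model = torch.nn.Linear(8, 8)
with maybe_enable_compiled_autograd(True):
    model(torch.randn(4, 8)).sum().backward()  # backward runs under compiled autograd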