From 4899761cba54ada75b02919181e5670b0ed24c20 Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Thu, 12 Feb 2026 16:37:11 -0800 Subject: [PATCH] add graph based ac --- .../compiler_toolkit/common_utils.py | 16 +++ .../deepseek_v3/parallelize.py | 3 + .../compiler_toolkit/llama3/parallelize.py | 3 + .../experiments/compiler_toolkit/passes.py | 88 ++++++++++++ .../compiler_toolkit/tests/test_passes.py | 132 +++++++++++++++++- 5 files changed, 240 insertions(+), 2 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/common_utils.py b/torchtitan/experiments/compiler_toolkit/common_utils.py index 8c8ed95e94..4ddada3e24 100644 --- a/torchtitan/experiments/compiler_toolkit/common_utils.py +++ b/torchtitan/experiments/compiler_toolkit/common_utils.py @@ -104,3 +104,19 @@ def create_extra_fsdp_pg(parallel_dims: ParallelDims) -> None: def get_extra_fsdp_pg_name(original_pg_name: str) -> str | None: """Look up the extra PG name for a given original FSDP PG name.""" return _EXTRA_FSDP_PG_REGISTRY.get(original_pg_name) + + +def maybe_disable_eager_ac(job_config: JobConfig) -> None: + """Disable eager AC when apply_sac graph pass is enabled. + + When apply_sac is used as a joint graph pass, eager activation checkpointing + must be disabled to avoid double-checkpointing. This must be called before + the model parallelization step that applies eager AC. + """ + joint_pass_names = getattr(job_config.compile, "joint_passes", []) + if "apply_sac" in joint_pass_names: + if job_config.activation_checkpoint.mode != "none": + logger.info( + "apply_sac graph pass is enabled, overriding eager AC mode to none" + ) + job_config.activation_checkpoint.mode = "none" diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py index 3207676b7e..e39cab1d8f 100644 --- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py @@ -15,6 +15,7 @@ from torchtitan.distributed import ParallelDims from torchtitan.experiments.compiler_toolkit.common_utils import ( disable_compile, + maybe_disable_eager_ac, parallelize_inputs, register_blockmask_pytree_node, ) @@ -72,6 +73,8 @@ def parallelize_deepseekv3( register_blockmask_pytree_node() + maybe_disable_eager_ac(job_config) + # Disable torch.compile over the model in the compiler toolkit style workflow with disable_compile(job_config): model = simple_fsdp_parallelize_deepseekv3(model, parallel_dims, job_config) diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py index 510f94aa97..c6a150149c 100644 --- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py @@ -14,6 +14,7 @@ from torchtitan.distributed import ParallelDims from torchtitan.experiments.compiler_toolkit.common_utils import ( disable_compile, + maybe_disable_eager_ac, parallelize_inputs, register_blockmask_pytree_node, ) @@ -59,6 +60,8 @@ def parallelize_llama( register_blockmask_pytree_node() + maybe_disable_eager_ac(job_config) + # Disable torch.compile over the model in the compiler toolkit style workflow with disable_compile(job_config): model = simple_fsdp_parallelize_llama(model, parallel_dims, job_config) diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index 9cea91bc12..3fd14bbf02 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -15,6 +15,7 @@ - Compiler passes: Applied to the partitioned forward/backward graphs """ +import operator from typing import Any, Sequence import torch @@ -27,6 +28,7 @@ from torch._inductor.fx_passes.overlap_manual_scheduling import manual_overlap_bucketing from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing from torch.fx.passes.regional_inductor import regional_inductor +from torch.utils.checkpoint import CheckpointPolicy from torchtitan.experiments.compiler_toolkit.cudagraph import ( CUDAGraphWrapper, get_static_input_indices, @@ -103,6 +105,91 @@ def validate_flex_attn_annotation_pass( return gm +# Default set of ops whose outputs should be saved (not recomputed) during +# activation checkpointing. These are compute-intensive or communication ops +# where recomputation is expensive. +DEFAULT_SAC_SAVE_OPS = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + torch.ops.aten._scaled_dot_product_attention_math.default, + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + torch.ops.aten.max.default, + torch._higher_order_ops.flex_attention, + torch.ops.torch_attn._varlen_attn, + torch._higher_order_ops.inductor_compiled_code, +} + + +def apply_sac_pass( + gm: torch.fx.GraphModule, + op_list_to_save: set | None = None, +) -> torch.fx.GraphModule: + """ + Apply selective activation checkpointing on the joint graph. + + This pass iterates over all call_function nodes in the joint graph and annotates + each with a CheckpointPolicy. Ops in ``op_list_to_save`` are marked MUST_SAVE + (their outputs are kept as activations for the backward pass), while all other + ops are marked PREFER_RECOMPUTE (their outputs may be discarded and recomputed + during backward). + + To reduce memory further, every second ``mm`` op is marked PREFER_RECOMPUTE + instead of MUST_SAVE, matching the behavior of the eager selective AC policy + in ``torchtitan.distributed.activation_checkpoint``. + + The annotations are later consumed by the min-cut partitioner + (``min_cut_rematerialization_partition``) to split the joint graph into separate + forward and backward graphs. + + Usage: set ``--compile.joint_passes apply_sac``. + + Args: + gm: The joint forward-backward graph module + op_list_to_save: Set of op targets whose outputs should be saved. + Defaults to DEFAULT_SAC_SAVE_OPS if None. + + Returns: + The annotated graph module + """ + if op_list_to_save is None: + op_list_to_save = DEFAULT_SAC_SAVE_OPS + + nodes = list(gm.graph.nodes) + output_node = nodes[-1].all_input_nodes[0] + mm_count = 0 + + for node in nodes: + if node.op != "call_function" or node.target is operator.getitem: + continue + + node.meta["ac_graph_id"] = 0 + + if node.target is torch.ops.aten.mm.default: + mm_count += 1 + # Save every odd mm, recompute every even mm + if mm_count % 2 == 0: + node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE + else: + node.meta["recompute"] = CheckpointPolicy.MUST_SAVE + elif node.target in op_list_to_save: + node.meta["recompute"] = CheckpointPolicy.MUST_SAVE + else: + node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE + + if node is output_node: + break + + gm.recompile() + logger.info( + "Applied selective activation checkpointing (SAC) graph pass " + f"({mm_count} mm ops found, {mm_count - mm_count // 2} saved)" + ) + return gm + + # Apply activation checkpointing on joint graph before partitioner def fsdp_reshard_after_fwd_pass( gm: torch.fx.GraphModule, reshard_after_forward: bool @@ -318,4 +405,5 @@ def reassign_to_pg_pass( "inductor_decomposition": inductor_decomposition_pass, "fsdp_reshard_after_fwd": fsdp_reshard_after_fwd_pass, "validate_flex_attn_annotation": validate_flex_attn_annotation_pass, + "apply_sac": apply_sac_pass, } diff --git a/torchtitan/experiments/compiler_toolkit/tests/test_passes.py b/torchtitan/experiments/compiler_toolkit/tests/test_passes.py index 9fe7f74f74..90bf034c51 100644 --- a/torchtitan/experiments/compiler_toolkit/tests/test_passes.py +++ b/torchtitan/experiments/compiler_toolkit/tests/test_passes.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import operator + import torch import torch.nn as nn from torch._functorch.aot_autograd import aot_compile_joint_with_descriptors @@ -12,11 +14,15 @@ is_all_gather_into_tensor as is_all_gather, ) from torch.testing._internal.common_fsdp import FSDPTest -from torch.utils.checkpoint import checkpoint +from torch.testing._internal.common_utils import TestCase +from torch.utils.checkpoint import checkpoint, CheckpointPolicy from torchtitan.distributed import ParallelDims from torchtitan.experiments.compiler_toolkit.graph_utils import export_joint -from torchtitan.experiments.compiler_toolkit.passes import reassign_to_pg_pass +from torchtitan.experiments.compiler_toolkit.passes import ( + apply_sac_pass, + reassign_to_pg_pass, +) from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel @@ -200,6 +206,128 @@ def test_reassign_with_extra_pg(self): self.assertEqual(ag_new, ag_before) +class TestApplySACPass(TestCase): + """Unit tests for the apply_sac_pass joint graph pass.""" + + def _build_gm(self, op_targets): + """Build a GraphModule with a chain of call_function nodes. + + Each op in op_targets becomes a call_function node. The graph + structure is: placeholder(x), placeholder(y) -> op1 -> op2 -> ... -> output. + """ + graph = torch.fx.Graph() + x = graph.placeholder("x") + y = graph.placeholder("y") + last = x + for target in op_targets: + if target is operator.getitem: + last = graph.call_function(target, args=(last, 0)) + else: + last = graph.call_function(target, args=(last, y)) + graph.output(last) + return torch.fx.GraphModule(torch.nn.Module(), graph) + + def _get_call_function_nodes(self, gm): + """Return all call_function nodes from the graph.""" + return [n for n in gm.graph.nodes if n.op == "call_function"] + + def test_non_save_ops_marked_recompute(self): + """Ops not in the save list should be marked PREFER_RECOMPUTE.""" + gm = self._build_gm( + [ + torch.ops.aten.add.Tensor, + torch.ops.aten.relu.default, + ] + ) + apply_sac_pass(gm) + for node in self._get_call_function_nodes(gm): + self.assertEqual(node.meta["recompute"], CheckpointPolicy.PREFER_RECOMPUTE) + + def test_save_ops_marked_must_save(self): + """Non-mm ops in the save list should be marked MUST_SAVE.""" + custom_save = {torch.ops.aten.add.Tensor} + gm = self._build_gm([torch.ops.aten.add.Tensor]) + apply_sac_pass(gm, op_list_to_save=custom_save) + nodes = self._get_call_function_nodes(gm) + self.assertEqual(len(nodes), 1) + self.assertEqual(nodes[0].meta["recompute"], CheckpointPolicy.MUST_SAVE) + + def test_getitem_nodes_skipped(self): + """operator.getitem nodes should not receive any annotation.""" + gm = self._build_gm( + [ + torch.ops.aten.add.Tensor, + operator.getitem, + torch.ops.aten.relu.default, + ] + ) + apply_sac_pass(gm) + for node in self._get_call_function_nodes(gm): + if node.target is operator.getitem: + self.assertNotIn("recompute", node.meta) + self.assertNotIn("ac_graph_id", node.meta) + + def test_ac_graph_id_set(self): + """All annotated nodes should have ac_graph_id = 0.""" + gm = self._build_gm( + [ + torch.ops.aten.add.Tensor, + torch.ops.aten.mm.default, + torch.ops.aten.relu.default, + ] + ) + apply_sac_pass(gm) + for node in self._get_call_function_nodes(gm): + if node.target is not operator.getitem: + self.assertEqual(node.meta["ac_graph_id"], 0) + + def test_custom_op_list_to_save(self): + """A custom op_list_to_save should override the defaults.""" + custom_save = {torch.ops.aten.relu.default} + gm = self._build_gm( + [ + torch.ops.aten.add.Tensor, + torch.ops.aten.relu.default, + ] + ) + apply_sac_pass(gm, op_list_to_save=custom_save) + policies = { + n.target: n.meta["recompute"] for n in self._get_call_function_nodes(gm) + } + self.assertEqual( + policies[torch.ops.aten.add.Tensor], CheckpointPolicy.PREFER_RECOMPUTE + ) + self.assertEqual( + policies[torch.ops.aten.relu.default], CheckpointPolicy.MUST_SAVE + ) + + def test_mixed_mm_and_save_ops(self): + """Graph with both mm and other save ops are annotated correctly.""" + custom_save = {torch.ops.aten.mm.default, torch.ops.aten.max.default} + gm = self._build_gm( + [ + torch.ops.aten.mm.default, # 1st mm -> MUST_SAVE + torch.ops.aten.max.default, # in save list -> MUST_SAVE + torch.ops.aten.mm.default, # 2nd mm -> PREFER_RECOMPUTE + torch.ops.aten.add.Tensor, # not in save list -> PREFER_RECOMPUTE + torch.ops.aten.mm.default, # 3rd mm -> MUST_SAVE + ] + ) + apply_sac_pass(gm, op_list_to_save=custom_save) + nodes = self._get_call_function_nodes(gm) + expected = [ + (torch.ops.aten.mm.default, CheckpointPolicy.MUST_SAVE), + (torch.ops.aten.max.default, CheckpointPolicy.MUST_SAVE), + (torch.ops.aten.mm.default, CheckpointPolicy.PREFER_RECOMPUTE), + (torch.ops.aten.add.Tensor, CheckpointPolicy.PREFER_RECOMPUTE), + (torch.ops.aten.mm.default, CheckpointPolicy.MUST_SAVE), + ] + self.assertEqual(len(nodes), len(expected)) + for node, (target, policy) in zip(nodes, expected): + self.assertEqual(node.target, target) + self.assertEqual(node.meta["recompute"], policy, f"node {node.name}") + + if __name__ == "__main__": from torch.testing._internal.common_utils import run_tests