From 29a599b904e62edd188e3bd9f9c96d4c840d40df Mon Sep 17 00:00:00 2001
From: ruisizhang123
Date: Thu, 30 Oct 2025 16:26:53 -0700
Subject: [PATCH] [simplefsdp] fix region ac in zero 2

---
 torchtitan/experiments/simple_fsdp/README.md  |  4 +-
 torchtitan/experiments/simple_fsdp/backend.py | 35 +++++--
 .../experiments/simple_fsdp/compile_utils.py  | 94 +++++++++++++++++++
 .../simple_fsdp/deepseek_v3/parallelize.py    | 49 +++++-----
 .../simple_fsdp/llama3/parallelize.py         | 33 ++++---
 .../experiments/simple_fsdp/simple_fsdp.py    | 29 +-----
 6 files changed, 171 insertions(+), 73 deletions(-)
 create mode 100644 torchtitan/experiments/simple_fsdp/compile_utils.py

diff --git a/torchtitan/experiments/simple_fsdp/README.md b/torchtitan/experiments/simple_fsdp/README.md
index a49fa8ad5..ea4fb3272 100644
--- a/torchtitan/experiments/simple_fsdp/README.md
+++ b/torchtitan/experiments/simple_fsdp/README.md
@@ -3,11 +3,13 @@
 [![integration and numerics tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml?query=branch%3Amain)
 [![arXiv](https://img.shields.io/badge/arXiv-2411.00284-b31b1b.svg)](https://arxiv.org/abs/2411.00284)
 
-💡 **Note**: SimpleFSDP's composability with Mixed Precision Training and Tensor Parallel requires updates from latest PyTorch, which can be installed (e.g., for CUDA 12.6) via
+💡 **Note 1**: SimpleFSDP's composability with Mixed Precision Training and Tensor Parallel requires updates from the latest PyTorch, which can be installed (e.g., for CUDA 12.6) via
 ```bash
 pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall
 ```
 
+💡 **Note 2**: Some of SimpleFSDP's functionality (e.g., `reshard_after_forward`) is implemented via torch.compile. It is recommended to enable compile (`--compile.enable`) to get the intended behavior.
+
 This folder includes an experimental frontend implementation for [SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile](https://arxiv.org/abs/2411.00284). SimpleFSDP is a compiler-based Fully Sharded Data Parallel (FSDP) framework, which has a simple implementation for maintenance and composability, allows full computation-communication graph tracing, and brings performance enhancement via compiler backend optimizations.
 
 ### Run SimpleFSDP Training on Llama3 & DeepSeek_v3
diff --git a/torchtitan/experiments/simple_fsdp/backend.py b/torchtitan/experiments/simple_fsdp/backend.py
index 36abe4ad0..4bc654bcd 100644
--- a/torchtitan/experiments/simple_fsdp/backend.py
+++ b/torchtitan/experiments/simple_fsdp/backend.py
@@ -4,20 +4,25 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Any, Union
+from typing import Any
 
 import torch
+import torch._functorch.config as functorch_config
 
+from .compile_utils import annotate_fsdp_all_gather
 
-def get_compile_backend(backend_name: str) -> Union[str, callable]:
+
+def get_compile_backend(
+    backend_name: str, fsdp_reshard_after_forward: bool
+) -> callable:
     # return the compile backends used in SimpleFSDP training
     # Step1: check if backend_name is inside available torch.compile backends
     # Step2: check if the backend_name has been registered as a customized backend
     available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
 
-    if backend_name in available_torch_backend:
-        return backend_name
-    if backend_name == "aot_eager_autobucketing":
+    if backend_name in available_torch_backend:
+        backend = torch._dynamo.lookup_backend(backend_name)
+    elif backend_name == "aot_eager_autobucketing":
         # Perform auto optimization in aten fx-level and execute code in aot_eager backend
         # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
         from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
@@ -46,4 +51,22 @@ def aten_autobucketing_reordering_pass(
     else:
         raise AssertionError(f"Unsupported customized backend: {backend_name}")
 
-    return backend
+    def joint_ac_pass(
+        gm: torch.fx.GraphModule, example_inputs: Any
+    ) -> torch.fx.GraphModule:
+        # This pass implements SimpleFSDP's fsdp_reshard_after_forward behavior:
+        # when fsdp_reshard_after_forward is True, it annotates SimpleFSDP all-gathers
+        # with CheckpointPolicy.MUST_RECOMPUTE;
+        # when fsdp_reshard_after_forward is False, it annotates SimpleFSDP all-gathers
+        # with CheckpointPolicy.MUST_SAVE.
+        gm = annotate_fsdp_all_gather(gm, fsdp_reshard_after_forward)
+        gm.recompile()
+        return gm
+
+    def simple_fsdp_custom_pass(*args, **kwargs):
+        # The AC pass has to operate on the joint graph, before the partitioner runs,
+        # for the AC annotations to take effect.
+        with functorch_config.patch("joint_custom_pass", joint_ac_pass):
+            return backend(*args, **kwargs)
+
+    return simple_fsdp_custom_pass
diff --git a/torchtitan/experiments/simple_fsdp/compile_utils.py b/torchtitan/experiments/simple_fsdp/compile_utils.py
new file mode 100644
index 000000000..273960ddf
--- /dev/null
+++ b/torchtitan/experiments/simple_fsdp/compile_utils.py
@@ -0,0 +1,94 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+from torch.utils.checkpoint import CheckpointPolicy
+
+
+def is_graph_input(node: torch.fx.Node) -> bool:
+    return node.op == "placeholder"
+
+
+def is_wait_tensor(node: torch.fx.Node) -> bool:
+    return (
+        node.op == "call_function"
+        and node.target == torch.ops._c10d_functional.wait_tensor.default
+    )
+
+
+def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:
+    return (
+        node.op == "call_function"
+        and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
+    )
+
+
+def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
+    """
+    Returns True if the node is a wait_tensor node that is the result of an all_gather
+    that can be arbitrarily prefetched, i.e., if all its recursive inputs are
+    single-input operators that lead to a graph input.
+    """
+    if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
+        n: torch.fx.Node = node.all_input_nodes[0]
+        while len(n.all_input_nodes) == 1:
+            if is_graph_input(n.all_input_nodes[0]):
+                return True
+            n = n.all_input_nodes[0]
+    return False
+
+
+def annotate_fsdp_all_gather(
+    gm: torch.fx.GraphModule, reshard_after_forward: bool
+) -> torch.fx.GraphModule:
+    """
+    Annotate SimpleFSDP all_gather nodes for recompute (reshard_after_forward=True) or
+    save (False); registered via torch._functorch.config.joint_custom_pass in backend.py.
+    """
+    graph = gm.graph
+
+    def force_recompute_node(node):
+        if reshard_after_forward:
+            node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
+        else:
+            node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
+        # ac_graph_id is used in the partitioner to decide
+        # if two nodes which have AC applied come from different
+        # AC regions. This is needed because nodes on the boundary
+        # of two AC regions are marked as MUST_SAVE. In our case
+        # we just use a large value of ac_graph_id so that
+        # all nodes we tag for recomputation do indeed get recomputed
+        # and are not influenced by other nodes in the graph with
+        # nearby ac_graph_id values.
+        node.meta["ac_graph_id"] = 1000
+
+    # Make all-gather nodes (and related nodes) recomputable, to circumvent
+    # https://github.com/pytorch/pytorch/issues/136433
+    for node in graph.nodes:
+        if is_wait_tensor_from_fsdp(node):
+            ag_node = node.args[0]
+            force_recompute_node(ag_node)  # all_gather
+            force_recompute_node(node)  # wait_tensor
+            # Force-recompute the slice that comes after the wait
+            for user in node.users:
+                if (
+                    user.op == "call_function"
+                    and user.target == torch.ops.aten.slice.Tensor
+                ):
+                    force_recompute_node(user)
+            # Force-recompute potential dtype casts that feed the all_gather
+            if (
+                ag_node.all_input_nodes[0].op == "call_function"
+                and ag_node.args[0].target
+                == torch.ops.prims.convert_element_type.default
+            ):
+                force_recompute_node(ag_node.all_input_nodes[0])
+
+    return gm
diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
index ac6f9bdc9..2ae1c517f 100644
--- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
+++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
@@ -10,16 +10,18 @@
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
+
+from torchtitan.distributed.activation_checkpoint import apply_ac
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
 from torchtitan.models.deepseek_v3.infra.parallelize import (
-    apply_ac,
     apply_moe_ep_tp,
     apply_non_moe_tp,
 )
 from torchtitan.tools.logging import logger
-from ..simple_fsdp import data_parallel,
MixedPrecisionPolicy +from ..backend import get_compile_backend +from ..simple_fsdp import data_parallel, MixedPrecisionPolicy # Adapted from llama4/infra/parallelize.py def parallelize_deepseekv3( @@ -91,20 +93,6 @@ def parallelize_deepseekv3( reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], ) - match job_config.parallelism.fsdp_reshard_after_forward: - case "always": - reshard_after_forward = True - case "never": - reshard_after_forward = False - case "default": - # For PP, by default do not reshard after forward to avoid per-microbatch - # all-gathers, which can be expensive and non-overlapped - reshard_after_forward = not parallel_dims.pp_enabled - case _: - raise ValueError( - f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}." - ) - # apply data parallel dp_mesh: DeviceMesh | None = None if ( @@ -155,9 +143,7 @@ def parallelize_deepseekv3( transformer_block.moe.experts, dp_mod_ep_mesh, dp_mode, - ac_mode=job_config.activation_checkpoint.mode, mp_policy=mp_policy, - reshard_after_forward=reshard_after_forward, shard_dim=experts_shard_dim, reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor, ) @@ -166,9 +152,7 @@ def parallelize_deepseekv3( model, dp_mesh, dp_mode, - ac_mode=job_config.activation_checkpoint.mode, mp_policy=mp_policy, - reshard_after_forward=reshard_after_forward, ) logger.info( @@ -178,6 +162,29 @@ def parallelize_deepseekv3( if job_config.compile.enable: torch._inductor.config.reorder_for_peak_memory = False torch._dynamo.config.capture_scalar_outputs = True - model = torch.compile(model, backend=job_config.compile.backend, fullgraph=True) + + match job_config.parallelism.fsdp_reshard_after_forward: + case "always": + fsdp_reshard_after_forward = True + case "never": + fsdp_reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + fsdp_reshard_after_forward = not parallel_dims.pp_enabled + case _: + raise ValueError( + f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}." + ) + + backend = ( + getattr(job_config.compile, "model_backend_override", None) + or job_config.compile.backend + ) + model = torch.compile( + model, + backend=get_compile_backend(backend, fsdp_reshard_after_forward), + fullgraph=True, + ) return model diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py index d61e74a5d..1d8bfc500 100644 --- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py @@ -112,27 +112,11 @@ def parallelize_llama( reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], ) - match job_config.parallelism.fsdp_reshard_after_forward: - case "always": - reshard_after_forward = True - case "never": - reshard_after_forward = False - case "default": - # For PP, by default do not reshard after forward to avoid per-microbatch - # all-gathers, which can be expensive and non-overlapped - reshard_after_forward = not parallel_dims.pp_enabled - case _: - raise ValueError( - f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}." 
- ) - model = data_parallel( model, parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], mode=dp_mode, - ac_mode=job_config.activation_checkpoint.mode, mp_policy=mp_policy, - reshard_after_forward=reshard_after_forward, ) logger.info( "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode @@ -140,13 +124,28 @@ def parallelize_llama( if job_config.compile.enable and "model" in job_config.compile.components: torch._inductor.config.reorder_for_peak_memory = False + + match job_config.parallelism.fsdp_reshard_after_forward: + case "always": + fsdp_reshard_after_forward = True + case "never": + fsdp_reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + fsdp_reshard_after_forward = not parallel_dims.pp_enabled + case _: + raise ValueError( + f"Invalid fsdp_reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}." + ) + backend = ( getattr(job_config.compile, "model_backend_override", None) or job_config.compile.backend ) model = torch.compile( model, - backend=get_compile_backend(backend), + backend=get_compile_backend(backend, fsdp_reshard_after_forward), fullgraph=True, ) diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py index 737b6d3ec..4b0827748 100644 --- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py +++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py @@ -23,7 +23,6 @@ from torch.distributed.tensor._redistribute import redistribute_local_tensor from torch.distributed.tensor.placement_types import _StridedShard, Placement from torch.utils.checkpoint import ( - checkpoint, CheckpointPolicy, create_selective_checkpoint_contexts, ) @@ -208,9 +207,7 @@ def __init__( device_mesh, param_sharding, mode, - regional_ac, mp_policy, - reshard_after_forward, reduction_divide_factor, ): super().__init__() @@ -225,11 +222,9 @@ def __init__( if reduction_divide_factor is not None else Partial(reduce_op="avg") ] * self.device_mesh.ndim - self.regional_ac = regional_ac mp_policy = mp_policy or MixedPrecisionPolicy() self.param_dtype = mp_policy.param_dtype self.reduce_dtype = mp_policy.reduce_dtype - self.reshard_after_forward = reshard_after_forward def replicate_compute(self, x: DTensor) -> torch.Tensor: # data parallel runtime replicate parameters and do local compute @@ -292,21 +287,7 @@ def forward(self, x: DTensor) -> torch.Tensor: if not _active_parametrization: return x - if ( - self.regional_ac - and self.mode in ("fully_shard", "hybrid_shard") - and self.reshard_after_forward - ): - # apply checkpointing to implement reshard_after_forward - output = checkpoint( - self.replicate_compute, - x, - use_reentrant=False, - context_fn=fsdp_policy, - ) - else: - output = self.replicate_compute(x) - + output = self.replicate_compute(x) return output @@ -314,9 +295,7 @@ def data_parallel( model: nn.Module, device_mesh: DeviceMesh, mode: str = "replicate", - ac_mode: str = "none", mp_policy: MixedPrecisionPolicy | None = None, - reshard_after_forward: bool = True, shard_dim: int = 0, reduction_divide_factor: float | None = None, ): @@ -335,9 +314,6 @@ def data_parallel( modules = list(model.modules()) - # apply regional ac (with fsdp_policy) if no global ac is to be applied - regional_ac = ac_mode == "none" - for mod in modules: params_dict = dict(mod.named_parameters(recurse=False)) # we shouldn't apply data parallel to the modules that are already @@ -366,7 +342,6 
@@ def data_parallel( # device_mesh, # param_sharding, # mode, - # regional_ac, # mp_policy=mp_policy, # ), # unsafe=True, @@ -379,9 +354,7 @@ def data_parallel( device_mesh, param_sharding, mode, - regional_ac, mp_policy=mp_policy, - reshard_after_forward=reshard_after_forward, reduction_divide_factor=reduction_divide_factor, ), )
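
Reviewer note (not part of the patch): the snippet below sketches how the new `get_compile_backend(backend_name, fsdp_reshard_after_forward)` wrapper is meant to be driven, mirroring the flow in the updated `llama3/parallelize.py` and `deepseek_v3/parallelize.py`. It assumes `torch.distributed` is initialized and a 1-D `DeviceMesh` over the data-parallel ranks is available; the helper name `compile_with_simple_fsdp` and the choice of the `inductor` backend are illustrative only.

```python
# Minimal usage sketch of the new compile-backend wrapper (illustrative, not part of the patch).
import torch

from torchtitan.experiments.simple_fsdp.backend import get_compile_backend
from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel


def compile_with_simple_fsdp(
    model: torch.nn.Module,
    dp_mesh,  # a 1-D DeviceMesh over the data-parallel ranks (assumed to exist)
    reshard_after_forward: bool = True,
) -> torch.nn.Module:
    # Shard parameters with SimpleFSDP ("fully_shard" mode).
    model = data_parallel(model, dp_mesh, mode="fully_shard")
    # The wrapped backend patches torch._functorch.config.joint_custom_pass so the
    # joint-graph pass tags SimpleFSDP all-gathers as MUST_RECOMPUTE (reshard after
    # forward) or MUST_SAVE (keep params unsharded through backward, zero-2 style).
    backend = get_compile_backend("inductor", reshard_after_forward)
    return torch.compile(model, backend=backend, fullgraph=True)
```

Since the reshard/no-reshard decision now lives entirely in the compile backend rather than in an eager `checkpoint()` wrapper, it only takes effect when compile is enabled, which is why the README gains Note 2 above.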