# [simplefsdp] fix region ac in zero2-style FSDP #1970
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
New file containing the FSDP all-gather annotation pass (`@@ -0,0 +1,94 @@`):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.utils.checkpoint import CheckpointPolicy


def is_graph_input(node: torch.fx.Node) -> bool:
    return node.op == "placeholder"


def is_wait_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.wait_tensor.default
    )


def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
    )


def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
    """
    Returns True if the node is a wait_tensor node that is the result of an all_gather
    that can be arbitrarily prefetched, i.e., if all its recursive inputs are
    single-input operators that lead to a graph input.
    """
    if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
        n: torch.fx.Node = node.all_input_nodes[0]
        while len(n.all_input_nodes) == 1:
            if is_graph_input(n.all_input_nodes[0]):
                return True
            n = n.all_input_nodes[0]
    return False


def annotate_fsdp_all_gather(
    gm: torch.fx.GraphModule, reshard_after_forward: bool
) -> torch.fx.GraphModule:
    """
    Force recompute all_gather nodes from simple fsdp in the graph.
    This pass should be added in torch._inductor.config.joint_custom_post_pass
    """
    graph = gm.graph

    def force_recompute_node(node):
        if reshard_after_forward:
            node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
        else:
            node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
        # ac_graph_id is used in the partitioner to decide
        # if two nodes which have AC applied come from different
        # AC regions. This is needed because nodes on the boundary
        # of two AC regions are marked as MUST_SAVE. In our case
        # we just add a large value of ac_graph_id so that
        # all nodes we tag for recomputation do indeed get recomputed
        # and are not influenced by other nodes in the graph with
        # nearby ac_graph_id values
        node.meta["ac_graph_id"] = 0

    # Make all-gather nodes (and related nodes) recomputable, to circumvent
    # https://github.com/pytorch/pytorch/issues/136433
    for node in graph.nodes:
        if is_wait_tensor_from_fsdp(node):
            ag_node = node.args[0]
            force_recompute_node(ag_node)  # all_gather
            force_recompute_node(node)  # wait_tensor
            # Force-recompute slice that comes after wait
            for user in node.users:
                if (
                    user.op == "call_function"
                    and user.target == torch.ops.aten.slice.Tensor
                ):
                    force_recompute_node(user)
            # Force-recompute potential dtype casts from all_gather
            if (
                ag_node.all_input_nodes[0].op == "call_function"
                and ag_node.args[0].target
                == torch.ops.prims.convert_element_type.default
            ):
                force_recompute_node(ag_node.all_input_nodes[0])

    return gm
```

> **Review comment** (on lines +58 to +61): does it work with full AC? IIUC these flags are for SAC API.
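For intuition, here is a small hand-built sketch (not part of the PR) of the FX pattern `is_wait_tensor_from_fsdp` is meant to match: a graph input feeding `all_gather_into_tensor`, whose result is consumed by `wait_tensor`. The graph is only constructed, never executed, so no process group is needed; the collective arguments are placeholder values, and the helper functions from the file above are assumed to be in scope.

```python
import torch
import torch.distributed  # makes the _c10d_functional collective ops available

# Toy FX graph: placeholder -> all_gather_into_tensor -> wait_tensor.
g = torch.fx.Graph()
x = g.placeholder("sharded_param")
ag = g.call_function(
    torch.ops._c10d_functional.all_gather_into_tensor.default,
    args=(x, 2, "0"),  # placeholder (input, group_size, group_name) values
)
wait = g.call_function(torch.ops._c10d_functional.wait_tensor.default, args=(ag,))
g.output(wait)

# The wait node is recognized as an FSDP-style all-gather, so the annotation
# pass above would tag `ag`, `wait`, and any trailing slice for recomputation.
assert is_wait_tensor_from_fsdp(wait)
```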
Changes to the SimpleFSDP compile-backend helper:

```diff
@@ -4,20 +4,23 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Union
+from typing import Any
 
 import torch
+import torch._functorch.config as functorch_config
+
+from .activation_checkpoint import annotate_fsdp_all_gather
 
 
-def get_compile_backend(backend_name: str) -> Union[str, callable]:
+def get_compile_backend(backend_name: str, reshard_after_forward: bool) -> callable:
     # return the compile backends used in SimpleFSDP training
     # Step1: check if backend_name is inside available torch.compile backends
     # Step2: check if the backend_name has been registered as a customized backend
     available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
+
     if backend_name in available_torch_backend:
-        return backend_name
-
-    if backend_name == "aot_eager_autobucketing":
+        backend = torch._dynamo.lookup_backend(backend_name)
+    elif backend_name == "aot_eager_autobucketing":
         # Perform auto optimization in aten fx-level and execute code in aot_eager backend
         # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
         from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
@@ -46,4 +49,22 @@ def aten_autobucketing_reordering_pass(
     else:
         raise AssertionError(f"Unsupported customized backend: {backend_name}")
 
-    return backend
+    def joint_ac_pass(
+        gm: torch.fx.GraphModule, example_inputs: Any
+    ) -> torch.fx.GraphModule:
+        # This pass implements simplefsdp's reshard_after_forward behavior:
+        # when reshard_after_forward is set to True, it annotates simple_fsdp AG
+        # as CheckpointPolicy.MUST_RECOMPUTE;
+        # when reshard_after_forward is set to False, it annotates simple_fsdp AG
+        # as CheckpointPolicy.MUST_SAVE.
+        gm = annotate_fsdp_all_gather(gm, reshard_after_forward)
+        gm.recompile()
+        return gm
+
+    def simple_fsdp_custom_pass(*args, **kwargs):
+        # The AC pass has to operate on the joint graph, before the partitioner,
+        # for the AC annotation to take effect.
+        with functorch_config.patch("joint_custom_pass", joint_ac_pass):
+            return backend(*args, **kwargs)
+
+    return simple_fsdp_custom_pass
```

> **Review comment** (on the new `get_compile_backend` signature): since it's not related to replicate part of the API.
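As a rough usage sketch of the new signature (the import path, model, and backend name below are illustrative assumptions, and it presumes a PyTorch build that exposes the `joint_custom_pass` functorch config used above):

```python
import torch
import torch.nn as nn

# Hypothetical import path, for illustration only; in the PR the helper lives
# in the SimpleFSDP experiment's backend module.
from torchtitan.experiments.simple_fsdp.backends import get_compile_backend

model = nn.Linear(16, 16)

# "aot_eager" is a built-in torch.compile backend, so it is looked up directly;
# the returned wrapper installs the joint-graph AC annotation pass around it.
backend = get_compile_backend("aot_eager", reshard_after_forward=True)
compiled_model = torch.compile(model, backend=backend, fullgraph=True)
```

On a toy module like this the pass simply finds no FSDP all-gathers to annotate; it only changes behavior for models whose parameters are managed by SimpleFSDP.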
Changes to the Llama parallelize function:

```diff
@@ -112,6 +112,20 @@ def parallelize_llama(
         reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
     )
 
+    model = data_parallel(
+        model,
+        parallel_dims.world_mesh[tuple(dp_mesh_dim_names)],
+        mode=dp_mode,
+        ac_mode=job_config.activation_checkpoint.mode,
+        mp_policy=mp_policy,
+    )
+    logger.info(
+        "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode
+    )
+
+    if job_config.compile.enable and "model" in job_config.compile.components:
+        torch._inductor.config.reorder_for_peak_memory = False
+
     match job_config.parallelism.fsdp_reshard_after_forward:
         case "always":
             reshard_after_forward = True
@@ -126,27 +140,13 @@ def parallelize_llama(
             f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
         )
 
-    model = data_parallel(
-        model,
-        parallel_dims.world_mesh[tuple(dp_mesh_dim_names)],
-        mode=dp_mode,
-        ac_mode=job_config.activation_checkpoint.mode,
-        mp_policy=mp_policy,
-        reshard_after_forward=reshard_after_forward,
-    )
-    logger.info(
-        "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode
-    )
-
     if job_config.compile.enable and "model" in job_config.compile.components:
-        torch._inductor.config.reorder_for_peak_memory = False
-
         backend = (
             getattr(job_config.compile, "model_backend_override", None)
             or job_config.compile.backend
        )
         model = torch.compile(
             model,
-            backend=get_compile_backend(backend),
+            backend=get_compile_backend(backend, reshard_after_forward),
             fullgraph=True,
         )
```

> **Review comment** (on `ac_mode=job_config.activation_checkpoint.mode,`): we can remove
>
> **Review comment** (on `torch._inductor.config.reorder_for_peak_memory = False`): What is this flag for? Is the default True?
Changes to the SimpleFSDP data-parallel module:

```diff
@@ -23,7 +23,6 @@
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
 from torch.utils.checkpoint import (
-    checkpoint,
     CheckpointPolicy,
     create_selective_checkpoint_contexts,
 )
@@ -210,7 +209,6 @@ def __init__(
         mode,
         regional_ac,
         mp_policy,
-        reshard_after_forward,
         reduction_divide_factor,
     ):
         super().__init__()
@@ -229,7 +227,6 @@ def __init__(
         mp_policy = mp_policy or MixedPrecisionPolicy()
         self.param_dtype = mp_policy.param_dtype
         self.reduce_dtype = mp_policy.reduce_dtype
-        self.reshard_after_forward = reshard_after_forward
 
     def replicate_compute(self, x: DTensor) -> torch.Tensor:
         # data parallel runtime replicate parameters and do local compute
@@ -292,21 +289,7 @@ def forward(self, x: DTensor) -> torch.Tensor:
         if not _active_parametrization:
             return x
 
-        if (
-            self.regional_ac
-            and self.mode in ("fully_shard", "hybrid_shard")
-            and self.reshard_after_forward
-        ):
-            # apply checkpointing to implement reshard_after_forward
-            output = checkpoint(
-                self.replicate_compute,
-                x,
-                use_reentrant=False,
-                context_fn=fsdp_policy,
-            )
-        else:
-            output = self.replicate_compute(x)
-
+        output = self.replicate_compute(x)
         return output
 
 
@@ -316,7 +299,6 @@ def data_parallel(
     mode: str = "replicate",
     ac_mode: str = "none",
     mp_policy: MixedPrecisionPolicy | None = None,
-    reshard_after_forward: bool = True,
     shard_dim: int = 0,
     reduction_divide_factor: float | None = None,
 ):
@@ -381,7 +363,6 @@ def data_parallel(
             mode,
             regional_ac,
             mp_policy=mp_policy,
-            reshard_after_forward=reshard_after_forward,
             reduction_divide_factor=reduction_divide_factor,
         ),
     )
```

> **Review comment** (on the removed `self.regional_ac` line): remove
>
> **Review comment** (on the removed `context_fn=fsdp_policy,` line): remove
> **Review comment** (file-level): This file is for fsdp_reshard_after_forward, not AC in general, so should be renamed.
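For background on the SAC question raised earlier: the branch removed from `forward` wrapped `replicate_compute` in `torch.utils.checkpoint.checkpoint` with a selective-checkpoint context (`context_fn=fsdp_policy`). Below is a minimal sketch of that style of policy, here called `example_fsdp_policy`; it is an illustrative assumption, not the `fsdp_policy` actually defined in this file.

```python
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)


def _recompute_fsdp_collectives(ctx, op, *args, **kwargs):
    # Recompute the all-gathered parameters in backward instead of saving them,
    # which is what releases the unsharded parameters after forward.
    if op in (
        torch.ops._c10d_functional.all_gather_into_tensor.default,
        torch.ops._c10d_functional.wait_tensor.default,
    ):
        return CheckpointPolicy.MUST_RECOMPUTE
    return CheckpointPolicy.MUST_SAVE


def example_fsdp_policy():
    # Usable as the context_fn= argument of
    # torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=False, context_fn=example_fsdp_policy)
    return create_selective_checkpoint_contexts(_recompute_fsdp_collectives)
```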