94 changes: 94 additions & 0 deletions torchtitan/experiments/simple_fsdp/activation_checkpoint.py
Reviewer comment (Contributor): This file is for fsdp_reshard_after_forward, not AC in general, so should be renamed.
@@ -0,0 +1,94 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.utils.checkpoint import CheckpointPolicy


def is_graph_input(node: torch.fx.Node) -> bool:
return node.op == "placeholder"


def is_wait_tensor(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target == torch.ops._c10d_functional.wait_tensor.default
)


def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
)


def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
"""
Returns True if the node is a wait_tensor node that is the result of an all_gather
    that can be arbitrarily prefetched, i.e., if all its recursive inputs are
    single-input operators that lead to a graph input.
"""
if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
n: torch.fx.Node = node.all_input_nodes[0]
while len(n.all_input_nodes) == 1:
if is_graph_input(n.all_input_nodes[0]):
return True
n = n.all_input_nodes[0]
return False


Reviewer comment (Contributor), on the function below: Is the _to_copy for mixed precision included?

def annotate_fsdp_all_gather(
    gm: torch.fx.GraphModule, reshard_after_forward: bool
) -> None:
"""
Force recompute all_gather nodes from simple fsdp in the graph.
This pass should be added in torch._inductor.config.joint_custom_post_pass
"""
graph = gm.graph

def force_recompute_node(node):
if reshard_after_forward:
node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
else:
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
Reviewer comment (Contributor), on lines +58 to +61: does it work with full AC? IIUC these flags are for SAC API.
        # ac_graph_id is used in the partitioner to decide
        # if two nodes which have AC applied come from different
        # AC regions. This is needed because nodes on the boundary
        # of two AC regions are marked as MUST_SAVE. In our case
        # we just add a large value of ac_graph_id so that
        # all nodes we tag for recomputation do indeed get recomputed
        # and are not influenced by other nodes in the graph with
        # nearby ac_graph_id values.
node.meta["ac_graph_id"] = 0

# Make all-gather nodes (and related nodes) recomputable, to circumvent
# https://github.com/pytorch/pytorch/issues/136433
for node in graph.nodes:
if is_wait_tensor_from_fsdp(node):
ag_node = node.args[0]
force_recompute_node(ag_node) # all_gather
force_recompute_node(node) # wait_tensor
# Force-recompute slice that comes after wait
for user in node.users:
if (
user.op == "call_function"
and user.target == torch.ops.aten.slice.Tensor
):
force_recompute_node(user)
# Force-recompute potential dtype casts from all_gather
if (
ag_node.all_input_nodes[0].op == "call_function"
and ag_node.args[0].target
== torch.ops.prims.convert_element_type.default
):
force_recompute_node(ag_node.all_input_nodes[0])

return gm
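Reviewer context: the comment above asks whether these CheckpointPolicy flags work under full AC, since they are normally consumed by the selective activation checkpointing (SAC) API. For comparison, below is a minimal eager-mode SAC sketch built from the same CheckpointPolicy enum and the create_selective_checkpoint_contexts helper that simple_fsdp.py imports; the policy function and wrapper names are illustrative placeholders, not part of this PR, and the graph pass above instead writes the policy into node.meta for the partitioner rather than going through this eager API.

import torch
from functools import partial
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)


def _recompute_all_gathers_policy(ctx, op, *args, **kwargs):
    # Recompute functional-collective all-gathers in backward; save everything else.
    if op == torch.ops._c10d_functional.all_gather_into_tensor.default:
        return CheckpointPolicy.MUST_RECOMPUTE
    return CheckpointPolicy.MUST_SAVE


def run_with_sac(fn, *inputs):
    # Non-reentrant checkpoint with a per-region SAC context.
    return checkpoint(
        fn,
        *inputs,
        use_reentrant=False,
        context_fn=partial(
            create_selective_checkpoint_contexts, _recompute_all_gathers_policy
        ),
    )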
33 changes: 27 additions & 6 deletions torchtitan/experiments/simple_fsdp/backend.py
@@ -4,20 +4,23 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Union
from typing import Any

import torch
import torch._functorch.config as functorch_config

from .activation_checkpoint import annotate_fsdp_all_gather

def get_compile_backend(backend_name: str) -> Union[str, callable]:

def get_compile_backend(backend_name: str, reshard_after_forward: bool) -> callable:
Reviewer comment (Contributor): since it's not related to replicate part of the API

Suggested change:
-def get_compile_backend(backend_name: str, reshard_after_forward: bool) -> callable:
+def get_compile_backend(backend_name: str, fsdp_reshard_after_forward: bool) -> callable:
# return the compile backends used in SimpleFSDP training
# Step1: check if backend_name is inside available torch.compile backends
# Step2: check if the backend_name has been registered as a customized backend
available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
if backend_name in available_torch_backend:
return backend_name

if backend_name == "aot_eager_autobucketing":
if backend_name in available_torch_backend:
backend = torch._dynamo.lookup_backend(backend_name)
elif backend_name == "aot_eager_autobucketing":
# Perform auto optimization in aten fx-level and execute code in aot_eager backend
# The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
@@ -46,4 +49,22 @@ def aten_autobucketing_reordering_pass(
else:
raise AssertionError(f"Unsupported customized backend: {backend_name}")

return backend
def joint_ac_pass(
gm: torch.fx.GraphModule, example_inputs: Any
) -> torch.fx.GraphModule:
        # This pass implements SimpleFSDP's reshard_after_forward behavior:
        # when reshard_after_forward is True, it annotates simple_fsdp all-gather
        # nodes with CheckpointPolicy.MUST_RECOMPUTE; when it is False, it
        # annotates them with CheckpointPolicy.MUST_SAVE.
gm = annotate_fsdp_all_gather(gm, reshard_after_forward)
gm.recompile()
return gm

def simple_fsdp_custom_pass(*args, **kwargs):
        # The AC pass has to operate on the joint graph, before the partitioner
        # runs, for the AC annotations to take effect.
with functorch_config.patch("joint_custom_pass", joint_ac_pass):
return backend(*args, **kwargs)

return simple_fsdp_custom_pass
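For orientation, this is roughly how the wrapped backend is consumed by the parallelize functions further below. It is a hedged sketch only: the nn.Linear stand-in and the "aot_eager" backend choice are assumptions for illustration, while in torchtitan the model is the SimpleFSDP-parametrized model and the backend name comes from job_config.compile.

import torch
import torch.nn as nn

from torchtitan.experiments.simple_fsdp.backend import get_compile_backend

# Stand-in module; in the real flow this is the SimpleFSDP-wrapped model.
model = nn.Linear(8, 8)

# With reshard_after_forward=True, the joint pass marks FSDP all-gathers
# MUST_RECOMPUTE, so parameters are re-gathered in backward instead of saved.
compiled = torch.compile(
    model,
    backend=get_compile_backend("aot_eager", reshard_after_forward=True),
    fullgraph=True,
)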
47 changes: 28 additions & 19 deletions torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
@@ -10,16 +10,18 @@

from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims

from torchtitan.distributed.activation_checkpoint import apply_ac
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
from torchtitan.models.deepseek_v3.infra.parallelize import (
apply_ac,
apply_moe_ep_tp,
apply_non_moe_tp,
)
from torchtitan.tools.logging import logger

from ..simple_fsdp import data_parallel, MixedPrecisionPolicy
from ..backend import get_compile_backend

from ..simple_fsdp import data_parallel, MixedPrecisionPolicy

# Adapted from llama4/infra/parallelize.py
def parallelize_deepseekv3(
@@ -91,20 +93,6 @@ def parallelize_deepseekv3(
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
)

match job_config.parallelism.fsdp_reshard_after_forward:
case "always":
reshard_after_forward = True
case "never":
reshard_after_forward = False
case "default":
# For PP, by default do not reshard after forward to avoid per-microbatch
# all-gathers, which can be expensive and non-overlapped
reshard_after_forward = not parallel_dims.pp_enabled
case _:
raise ValueError(
f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
)

# apply data parallel
dp_mesh: DeviceMesh | None = None
if (
@@ -157,7 +145,6 @@ def parallelize_deepseekv3(
dp_mode,
ac_mode=job_config.activation_checkpoint.mode,
mp_policy=mp_policy,
reshard_after_forward=reshard_after_forward,
shard_dim=experts_shard_dim,
reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
)
@@ -168,7 +155,6 @@ def parallelize_deepseekv3(
dp_mode,
ac_mode=job_config.activation_checkpoint.mode,
mp_policy=mp_policy,
reshard_after_forward=reshard_after_forward,
)

logger.info(
@@ -178,6 +164,29 @@ def parallelize_deepseekv3(
if job_config.compile.enable:
torch._inductor.config.reorder_for_peak_memory = False
torch._dynamo.config.capture_scalar_outputs = True
model = torch.compile(model, backend=job_config.compile.backend, fullgraph=True)

match job_config.parallelism.fsdp_reshard_after_forward:
case "always":
reshard_after_forward = True
case "never":
reshard_after_forward = False
case "default":
# For PP, by default do not reshard after forward to avoid per-microbatch
# all-gathers, which can be expensive and non-overlapped
reshard_after_forward = not parallel_dims.pp_enabled
case _:
raise ValueError(
f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
)

backend = (
getattr(job_config.compile, "model_backend_override", None)
or job_config.compile.backend
)
model = torch.compile(
model,
backend=get_compile_backend(backend, reshard_after_forward),
fullgraph=True,
)

return model
30 changes: 15 additions & 15 deletions torchtitan/experiments/simple_fsdp/llama3/parallelize.py
@@ -112,6 +112,20 @@ def parallelize_llama(
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
)

model = data_parallel(
model,
parallel_dims.world_mesh[tuple(dp_mesh_dim_names)],
mode=dp_mode,
ac_mode=job_config.activation_checkpoint.mode,
Reviewer comment (Contributor), on the ac_mode argument: we can remove ac_mode
mp_policy=mp_policy,
)
logger.info(
"Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode
)

if job_config.compile.enable and "model" in job_config.compile.components:
torch._inductor.config.reorder_for_peak_memory = False
Reviewer comment (Contributor), on the reorder_for_peak_memory line above: What is this flag for? Is the default True?
match job_config.parallelism.fsdp_reshard_after_forward:
case "always":
reshard_after_forward = True
@@ -126,27 +140,13 @@ def parallelize_llama(
f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
)

model = data_parallel(
model,
parallel_dims.world_mesh[tuple(dp_mesh_dim_names)],
mode=dp_mode,
ac_mode=job_config.activation_checkpoint.mode,
mp_policy=mp_policy,
reshard_after_forward=reshard_after_forward,
)
logger.info(
"Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode
)

if job_config.compile.enable and "model" in job_config.compile.components:
torch._inductor.config.reorder_for_peak_memory = False
backend = (
getattr(job_config.compile, "model_backend_override", None)
or job_config.compile.backend
)
model = torch.compile(
model,
backend=get_compile_backend(backend),
backend=get_compile_backend(backend, reshard_after_forward),
fullgraph=True,
)

21 changes: 1 addition & 20 deletions torchtitan/experiments/simple_fsdp/simple_fsdp.py
@@ -23,7 +23,6 @@
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor.placement_types import _StridedShard, Placement
from torch.utils.checkpoint import (
checkpoint,
CheckpointPolicy,
create_selective_checkpoint_contexts,
)
@@ -210,7 +209,6 @@ def __init__(
mode,
regional_ac,
mp_policy,
reshard_after_forward,
reduction_divide_factor,
):
super().__init__()
@@ -229,7 +227,6 @@ def __init__(
mp_policy = mp_policy or MixedPrecisionPolicy()
self.param_dtype = mp_policy.param_dtype
self.reduce_dtype = mp_policy.reduce_dtype
self.reshard_after_forward = reshard_after_forward

def replicate_compute(self, x: DTensor) -> torch.Tensor:
# data parallel runtime replicate parameters and do local compute
@@ -292,21 +289,7 @@ def forward(self, x: DTensor) -> torch.Tensor:
if not _active_parametrization:
return x

if (
self.regional_ac
Reviewer comment (Contributor), on regional_ac above: remove regional_ac in this file
and self.mode in ("fully_shard", "hybrid_shard")
and self.reshard_after_forward
):
# apply checkpointing to implement reshard_after_forward
output = checkpoint(
self.replicate_compute,
x,
use_reentrant=False,
context_fn=fsdp_policy,
Reviewer comment (Contributor), on fsdp_policy above: remove fsdp_policy in this file
)
else:
output = self.replicate_compute(x)

output = self.replicate_compute(x)
return output


@@ -316,7 +299,6 @@ def data_parallel(
mode: str = "replicate",
ac_mode: str = "none",
mp_policy: MixedPrecisionPolicy | None = None,
reshard_after_forward: bool = True,
shard_dim: int = 0,
reduction_divide_factor: float | None = None,
):
@@ -381,7 +363,6 @@ def data_parallel(
mode,
regional_ac,
mp_policy=mp_policy,
reshard_after_forward=reshard_after_forward,
reduction_divide_factor=reduction_divide_factor,
),
)