110 changes: 94 additions & 16 deletions torchtitan/components/validate.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Generator
from typing import Any, Generator

import torch
import torch.nn as nn
@@ -16,6 +16,7 @@
from torchtitan.config import JobConfig
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.hf_datasets.text_datasets import build_text_validation_dataloader
from torchtitan.protocols.model import BaseModelArgs
from torchtitan.tools import utils
from torchtitan.tools.logging import logger

@@ -50,6 +51,7 @@ def __init__(
dp_world_size: int,
dp_rank: int,
tokenizer: BaseTokenizer,
model_args: BaseModelArgs,
parallel_dims: ParallelDims,
loss_fn: LossFunction,
validation_context: Generator[None, None, None],
@@ -60,6 +62,8 @@
pp_has_last_stage: bool | None = None,
):
self.job_config = job_config
self.tokenizer = tokenizer
self.model_args = model_args
self.parallel_dims = parallel_dims
self.loss_fn = loss_fn
self.validation_dataloader = build_text_validation_dataloader(
@@ -82,6 +86,75 @@ def __init__(
"unequal sample counts across ranks when dataset is exhausted."
)

def post_dataloading_process(
self,
input_dict: dict[str, torch.Tensor],
labels: torch.Tensor,
Contributor comment: I feel we should consolidate this into input_dict["labels"]

device: torch.device,
Contributor comment: can we just use input_dict["inputs"].device?

model_parts: list[nn.Module],
) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
"""
Post-processing hook after data loading and before model forward pass.

This method processes the raw data from the dataloader and prepares it for
the model's forward pass. It separates the main input tensor from auxiliary
inputs and constructs additional keyword arguments (e.g., attention masks).

Args:
input_dict: Dictionary containing tensors from the dataloader. Must
contain an "input" key with the main input tensor. May contain
additional keys for auxiliary inputs (e.g., position ids).
labels: Target labels for the batch.
device: Device to use for creating new tensors (e.g., positions).
model_parts: List of model parts for accessing model methods.

Returns:
A tuple of (inputs, labels, extra_inputs, extra_kwargs) where:
- inputs: Main input tensor extracted from input_dict["input"].
- labels: Target labels (potentially modified by CP sharding).
- extra_inputs: Dict of auxiliary input tensors (all keys except
Contributor comment: maybe we can consolidate this into input_dict as well?

"input" from input_dict). These are passed to the model forward
but are NOT forwarded across pipeline parallel stages.
- extra_kwargs: Dict of additional keyword arguments for model forward.
These ARE forwarded across pipeline parallel stages. Contains
attention_masks if flex attention is enabled.

Note:
The distinction between extra_inputs and extra_kwargs is important for
pipeline parallelism: extra_kwargs are forwarded to all pipeline stages,
while extra_inputs are only available to the first stage.
"""
inputs = input_dict["input"]
extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
# For arguments, like attention_masks, we have to put them in a separate
# dict as extra_inputs are not forwarded to other stages in PP, but
# extra_kwargs are.
extra_kwargs: dict[str, Any] = {}

attn_type = getattr(self.model_args, "attn_type", "sdpa")
Contributor comment: It seems we introduce model_args to the validator because of this line. This condition is checked again in get_attention_masks; can we remove it here? We could return None instead of throwing in get_attention_masks when attn_type is sdpa. The argument is that the validator is not supposed to know ModelArgs details.

if attn_type in ["flex", "varlen"]:
extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
input_batch=inputs,
tokenizer=self.tokenizer,
extra_inputs=extra_inputs,
)
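A rough sketch of that suggestion (hypothetical, not part of this PR): get_attention_masks would return None for SDPA instead of raising, and the caller would only add masks when present, so the validator never needs model_args. ToyModel, its attn_type attribute, and the placeholder mask value are stand-ins for illustration only:

```python
from typing import Any


class ToyModel:
    # Stand-in for the real model; the actual method is
    # model.get_attention_masks(input_batch, tokenizer, extra_inputs).
    def __init__(self, attn_type: str) -> None:
        self.attn_type = attn_type

    def get_attention_masks(self, input_batch, tokenizer, extra_inputs):
        if self.attn_type == "sdpa":
            return None  # return None rather than raising for SDPA
        return {"flex_or_varlen_mask": ...}  # placeholder for real mask objects


# Caller side: no attn_type check and no model_args dependency.
extra_kwargs: dict[str, Any] = {}
masks = ToyModel("sdpa").get_attention_masks(None, None, {})
if masks is not None:
    extra_kwargs["attention_masks"] = masks
```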

if self.parallel_dims.cp_enabled:
attention_masks = extra_kwargs.get("attention_masks", None)
positions = torch.arange(
0, inputs.shape[1], dtype=torch.int32, device=device
).expand(inputs.shape)
(inputs, labels, positions), attention_masks = dist_utils.cp_shard(
self.parallel_dims.world_mesh["cp"],
(inputs, labels, positions),
attention_masks,
)
extra_kwargs["positions"] = positions
if attention_masks is not None:
extra_kwargs["attention_masks"] = attention_masks

return inputs, labels, extra_inputs, extra_kwargs
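A minimal, self-contained sketch of what this split produces, assuming CP is disabled and SDPA attention (so no masks or resharding are involved); the "position_ids" key is made up purely for illustration:

```python
import torch

# A toy batch shaped like what the dataloader yields; any key other than
# "input" (here a made-up "position_ids") becomes an auxiliary input.
input_dict = {
    "input": torch.randint(0, 32000, (2, 128)),
    "position_ids": torch.arange(128).expand(2, 128),
}

# The same split post_dataloading_process performs in this simple case:
inputs = input_dict["input"]
extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
extra_kwargs = {}  # would hold attention_masks (flex/varlen) and positions (CP)

print(inputs.shape)        # torch.Size([2, 128])
print(list(extra_inputs))  # ['position_ids'] -- consumed only by the first PP stage
print(extra_kwargs)        # {} -- kwargs here get forwarded to every PP stage
```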

@torch.no_grad()
def validate(
self,
@@ -108,38 +181,39 @@ def validate(
self.metrics_processor.ntokens_since_last_log += labels.numel()
for k, v in input_dict.items():
input_dict[k] = v.to(device_type)
inputs = input_dict["input"]
labels = labels.to(device_type)

optional_context_parallel_ctx = (
dist_utils.create_context_parallel_ctx(
cp_mesh=parallel_dims.world_mesh["cp"],
cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
cp_seq_dims=[1, 1] + [0 for _ in model_parts],
cp_no_restore_buffers={inputs, labels},
cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
)
if parallel_dims.cp_enabled
else None
# Create device object for post_dataloading_process
device = torch.device(device_type)

# Process data (extract inputs, handle attention masks, CP sharding)
inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process(
input_dict, labels, device, model_parts
)

if parallel_dims.pp_enabled:
assert self.pp_schedule is not None
assert self.pp_has_first_stage is not None
assert self.pp_has_last_stage is not None
# Pipeline Parallel forward inside eval() call
with self.validation_context(optional_context_parallel_ctx):
with self.validation_context():
targets, losses = (
(labels, []) if self.pp_has_last_stage else (None, None)
)
if self.pp_has_first_stage:
self.pp_schedule.eval(
inputs,
**extra_inputs,
**extra_kwargs,
target=targets,
losses=losses,
)
else:
self.pp_schedule.eval(target=targets, losses=losses)
self.pp_schedule.eval(
**extra_kwargs,
target=targets,
losses=losses,
)

# accumulate losses across pipeline microbatches
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
@@ -152,10 +226,12 @@ def validate(
else torch.tensor([-1.0], device=device_type)
)
else:
with self.validation_context(optional_context_parallel_ctx):
with self.validation_context():
assert len(model_parts) == 1
with self.maybe_enable_amp:
predictions = model_parts[0](inputs)
predictions = model_parts[0](
inputs, **extra_inputs, **extra_kwargs
)
loss = self.loss_fn(predictions, labels)

accumulated_losses.append(loss.detach())
@@ -184,6 +260,7 @@ def build_validator(
dp_world_size: int,
dp_rank: int,
tokenizer: BaseTokenizer,
model_args: BaseModelArgs,
parallel_dims: ParallelDims,
loss_fn: LossFunction,
validation_context: Generator[None, None, None],
@@ -199,6 +276,7 @@
dp_world_size=dp_world_size,
dp_rank=dp_rank,
tokenizer=tokenizer,
model_args=model_args,
parallel_dims=parallel_dims,
loss_fn=loss_fn,
validation_context=validation_context,
40 changes: 36 additions & 4 deletions torchtitan/distributed/utils.py
@@ -19,6 +19,7 @@

from torchtitan.config import Comm as CommConfig, Debug as DebugConfig, TORCH_DTYPE_MAP
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.protocols.model import AttentionMasksType
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import device_module, device_type

@@ -222,14 +223,11 @@

def get_train_context(enable_loss_parallel: bool) -> Generator[None, None, None]:
@contextlib.contextmanager
def context(cp_context: Generator[None, None, None] | None = None):
def context():
with contextlib.ExitStack() as stack:
if enable_loss_parallel:
stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())

if cp_context:
stack.enter_context(cp_context)

yield

return context
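A small illustrative sketch of the simplified call pattern: the context now only (optionally) enables loss parallel, and context parallelism is applied to the data up front via cp_shard below rather than through this context manager.

```python
from torchtitan.distributed import utils as dist_utils

# No cp_context argument anymore; callers simply enter the context.
train_context = dist_utils.get_train_context(enable_loss_parallel=False)
with train_context():
    pass  # forward / backward runs here; CP sharding already happened on the inputs
```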
@@ -515,3 +513,37 @@ def _clip_grad_norm_with_ep(
torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach)

return total_norm


def cp_shard(
Contributor comment: can this function also go to context_parallel.py? maybe we can consolidate with get_context_parallel_inputs since both are not long.

cp_mesh: DeviceMesh,
inputs: torch.Tensor,
attention_masks: AttentionMasksType | None,
):
from torch.distributed.tensor.experimental._attention import (
_context_parallel_shard,
_HeadTailLoadBalancer,
)
from torch.nn.attention.flex_attention import BlockMask

INPUT_SEQ_DIM = 1
seq_len = inputs[0].size(INPUT_SEQ_DIM)
cp_world_size = cp_mesh.size(0)
if isinstance(attention_masks, BlockMask):
raise ValueError(
"FlexAttention CP is not supported yet. Will come in the next PR."
)
else:
# For SDPA, we use the _HeadTailLoadBalancer.
load_balancer = _HeadTailLoadBalancer(
seq_len, cp_world_size, cp_mesh.device_type
)

inputs = _context_parallel_shard(
mesh=cp_mesh,
buffers=inputs,
seq_dims=tuple(1 for _ in inputs),
load_balancer=load_balancer,
)

return inputs, attention_masks
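A usage sketch mirroring the validate.py call site above — assuming a torchrun launch, CUDA devices, the whole world acting as the CP group, and plain SDPA (attention_masks=None); illustrative only, not part of the PR:

```python
import os

import torch
from torch.distributed.device_mesh import init_device_mesh

from torchtitan.distributed import utils as dist_utils

# Assumed: launched with torchrun, so WORLD_SIZE is set.
world_size = int(os.environ["WORLD_SIZE"])
cp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("cp",))

inputs = torch.randint(0, 32000, (2, 128), device="cuda")
labels = torch.randint(0, 32000, (2, 128), device="cuda")
positions = torch.arange(
    0, inputs.shape[1], dtype=torch.int32, device="cuda"
).expand(inputs.shape)

# Each rank keeps a head/tail slice along the sequence dimension; for SDPA the
# mask argument is passed through unchanged (None here).
(inputs, labels, positions), attention_masks = dist_utils.cp_shard(
    cp_mesh, (inputs, labels, positions), None
)
```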
60 changes: 39 additions & 21 deletions torchtitan/experiments/forge/example_train.py
@@ -152,42 +152,60 @@ def batch_generator(

yield input_dict, labels

def forward_backward_step(
def post_dataloading_process(
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
) -> torch.Tensor:
model_parts = self.model_parts
parallel_dims = self.parallel_dims

) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
inputs = input_dict["input"]
extra_kwargs = {}

if getattr(self.model_args, "attn_type", "sdpa") == "flex":
extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
# For arguments, like attention_masks, we have to put them in a separate
# dict as extra_inputs are not forwarded to other stages in PP, but
# extra_kwargs are.
extra_kwargs: dict[str, Any] = {}

attn_type = getattr(self.model_args, "attn_type", "sdpa")
if attn_type in ["flex", "varlen"]:
extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
input_batch=inputs,
tokenizer=self.tokenizer,
extra_inputs=extra_inputs,
)

optional_context_parallel_ctx = (
dist_utils.create_context_parallel_ctx(
cp_mesh=parallel_dims.world_mesh["cp"],
cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
cp_seq_dims=[1, 1] + [0 for _ in model_parts],
cp_no_restore_buffers={inputs, labels},
cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
if self.parallel_dims.cp_enabled:
attention_masks = extra_kwargs.get("attention_masks", None)
positions = torch.arange(
0, inputs.shape[1], dtype=torch.int32, device=self.device
).expand(inputs.shape)
(inputs, labels, positions), attention_masks = dist_utils.cp_shard(
self.parallel_dims.world_mesh["cp"],
(inputs, labels, positions),
attention_masks,
)
if parallel_dims.cp_enabled
else None
extra_kwargs["positions"] = positions
if attention_masks is not None:
extra_kwargs["attention_masks"] = attention_masks

return inputs, labels, extra_inputs, extra_kwargs

def forward_backward_step(
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
) -> torch.Tensor:
model_parts = self.model_parts
parallel_dims = self.parallel_dims

inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process(
input_dict, labels
)

if parallel_dims.pp_enabled:
# Pipeline Parallel forward / backward inside step() call
with self.train_context(optional_context_parallel_ctx):
with self.train_context():
targets, losses = (
(labels, []) if self.pp_has_last_stage else (None, None)
)
if self.pp_has_first_stage:
self.pp_schedule.step(
inputs,
**extra_inputs,
**extra_kwargs,
target=targets,
losses=losses,
@@ -211,10 +229,10 @@ def forward_backward_step(
)
else:
# Non-PP forward / backward
with self.train_context(optional_context_parallel_ctx):
with self.train_context():
assert len(model_parts) == 1
with self.maybe_enable_amp:
pred = model_parts[0](inputs, **extra_kwargs)
pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
loss = self.loss_fn(pred, labels)
# need to free to before bwd to avoid peaking memory
del pred
Original file line number Diff line number Diff line change
@@ -90,6 +90,7 @@ def parallelize_deepseekv3(
world_mesh["tp"],
loss_parallel=not job_config.parallelism.disable_loss_parallel,
enable_float8_tensorwise_tp=False,
cp_enabled=parallel_dims.cp_enabled,
)
maybe_enable_async_tp(job_config, world_mesh["tp"])

14 changes: 10 additions & 4 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -88,6 +88,7 @@ def parallelize_deepseekv3(
world_mesh["tp"],
loss_parallel=not job_config.parallelism.disable_loss_parallel,
enable_float8_tensorwise_tp=False,
cp_enabled=parallel_dims.cp_enabled,
)
maybe_enable_async_tp(job_config, world_mesh["tp"])

@@ -183,6 +184,7 @@ def apply_non_moe_tp(
tp_mesh: DeviceMesh,
loss_parallel: bool,
enable_float8_tensorwise_tp: bool,
cp_enabled: bool,
):
"""Apply tensor parallelism."""
# 1. Parallelize the embedding and shard its outputs (which are the first
@@ -221,14 +223,18 @@
# NOTE: At the cost of model code change, we can accelerate Sequence Parallel
# by folding (and unfolding) the batch dimension and the sequence dimension.
# Examples can be found at https://github.com/pytorch/torchtitan/pull/437
positions_sharding = Replicate() if cp_enabled else None
for transformer_block in model.layers.values():
layer_plan = {
"attention_norm": SequenceParallel(),
# NOTE: when the fourth argument (positions) is not None, its input layout
# and desired input layout should be Replicate()
"attention": prepare_module_input(
input_layouts=(Shard(1), Replicate(), None, None),
desired_input_layouts=(Replicate(), Replicate(), None, None),
input_layouts=(Shard(1), Replicate(), None, positions_sharding),
desired_input_layouts=(
Replicate(),
Replicate(),
None,
positions_sharding,
),
),
# NOTE: use_local_output=False make the output to be a DTensor instead of a plain Tensor
# so that the intermediate results k is generated as a DTensor and its gradient is
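To make the NOTE above concrete, a sketch of the plan entry for the attention module when CP is enabled — assuming prepare_module_input corresponds to torch's PrepareModuleInput parallel style (an assumption made for illustration; only the declared layouts matter for the point):

```python
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput

cp_enabled = True  # illustrative flag

# With CP on, the fourth forward argument (positions) is a real tensor, so its
# input and desired layouts are declared as Replicate(); with CP off it stays None.
positions_sharding = Replicate() if cp_enabled else None
attention_input_plan = PrepareModuleInput(
    input_layouts=(Shard(1), Replicate(), None, positions_sharding),
    desired_input_layouts=(Replicate(), Replicate(), None, positions_sharding),
)
```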