112 changes: 92 additions & 20 deletions torchtitan/components/validate.py
@@ -6,7 +6,7 @@

from collections.abc import Callable
from contextlib import AbstractContextManager
from typing import TypeAlias
from typing import Any, TypeAlias

import torch
import torch.nn as nn
@@ -17,14 +17,13 @@
from torchtitan.components.tokenizer import BaseTokenizer
from torchtitan.config import JobConfig
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.distributed.context_parallel import prepare_context_parallel_input
from torchtitan.hf_datasets.text_datasets import build_text_validation_dataloader
from torchtitan.protocols.model import BaseModelArgs
from torchtitan.tools import utils
from torchtitan.tools.logging import logger

ValidationContext: TypeAlias = Callable[
[AbstractContextManager[None] | None],
AbstractContextManager[None],
]
ValidationContext: TypeAlias = Callable[[], AbstractContextManager[None]]
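
For reference, a minimal sketch of a callable that satisfies the updated zero-argument alias (illustrative only, not part of this diff; the helper name is made up):

from contextlib import AbstractContextManager, nullcontext

def noop_validation_context() -> AbstractContextManager[None]:
    # A ValidationContext now takes no arguments and just returns a context
    # manager; the CP context is no longer threaded through it.
    return nullcontext()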


class BaseValidator:
@@ -57,6 +56,7 @@ def __init__(
dp_world_size: int,
dp_rank: int,
tokenizer: BaseTokenizer,
model_args: BaseModelArgs,
parallel_dims: ParallelDims,
loss_fn: LossFunction,
validation_context: ValidationContext,
@@ -67,6 +67,8 @@ def __init__(
pp_has_last_stage: bool | None = None,
):
self.job_config = job_config
self.tokenizer = tokenizer
self.model_args = model_args
self.parallel_dims = parallel_dims
self.loss_fn = loss_fn
self.validation_dataloader = build_text_validation_dataloader(
@@ -89,6 +91,71 @@ def __init__(
"unequal sample counts across ranks when dataset is exhausted."
)

def post_dataloading_process(
self,
input_dict: dict[str, torch.Tensor],
labels: torch.Tensor,
Contributor: I feel we should consolidate this into input_dict["labels"]

device: torch.device,
Contributor: can we just use input_dict["inputs"].device?

model_parts: list[nn.Module],
) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
"""
Post-processing hook after data loading and before model forward pass.

This method processes the raw data from the dataloader and prepares it for
the model's forward pass. It separates the main input tensor from auxiliary
inputs and constructs additional keyword arguments (e.g., attention masks).

Args:
input_dict: Dictionary containing tensors from the dataloader. Must
contain an "input" key with the main input tensor. May contain
additional keys for auxiliary inputs (e.g., position ids).
labels: Target labels for the batch.
device: Device to use for creating new tensors (e.g., positions).
model_parts: List of model parts for accessing model methods.

Returns:
A tuple of (inputs, labels, extra_inputs, extra_kwargs) where:
- inputs: Main input tensor extracted from input_dict["input"].
- labels: Target labels (potentially modified by CP sharding).
- extra_inputs: Dict of auxiliary input tensors (all keys except
"input" from input_dict). These are passed to the model forward
but are NOT forwarded across pipeline parallel stages.

Contributor: maybe we can consolidate this into input_dict as well?

- extra_kwargs: Dict of additional keyword arguments for model forward.
These ARE forwarded across pipeline parallel stages. Contains
attention_masks if flex attention is enabled.

Note:
The distinction between extra_inputs and extra_kwargs is important for
pipeline parallelism: extra_kwargs are forwarded to all pipeline stages,
while extra_inputs are only available to the first stage.
"""
inputs = input_dict["input"]
extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
# For arguments, like attention_masks, we have to put them in a separate
# dict as extra_inputs are not forwarded to other stages in PP, but
# extra_kwargs are.
extra_kwargs: dict[str, Any] = {}

attn_type = getattr(self.model_args, "attn_type", "sdpa")
Contributor: It seems we introduce model_args to the validator because of this line. This condition is checked again in get_attention_masks. Can we remove it here? We can return None instead of throwing in get_attention_masks when attn_type is sdpa. The argument is that the validator is not supposed to know ModelArgs details.

if attn_type in ["flex", "varlen"]:
# pyrefly: ignore [not-callable]
extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
input_batch=inputs,
tokenizer=self.tokenizer,
extra_inputs=extra_inputs,
)

if self.parallel_dims.cp_enabled:
inputs, labels, extra_kwargs = prepare_context_parallel_input(
inputs,
labels,
extra_kwargs,
self.parallel_dims.world_mesh["cp"],
device,
)

return inputs, labels, extra_inputs, extra_kwargs
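
To make the contract above concrete, a small self-contained sketch of how a batch splits into the four return values (illustrative only; the "position_ids" key and the shapes are made up, not part of this diff):

import torch

batch = {
    "input": torch.randint(0, 32000, (2, 16)),        # main token tensor
    "position_ids": torch.arange(16).expand(2, 16),   # auxiliary input
}
labels = torch.randint(0, 32000, (2, 16))

inputs = batch["input"]
extra_inputs = {k: v for k, v in batch.items() if k != "input"}  # first PP stage only
extra_kwargs: dict = {}  # e.g. attention_masks; forwarded to every PP stage

assert set(extra_inputs) == {"position_ids"}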

@torch.no_grad()
# pyrefly: ignore [bad-override]
def validate(
@@ -117,38 +184,39 @@ def validate(
self.metrics_processor.ntokens_since_last_log += labels.numel()
for k, v in input_dict.items():
input_dict[k] = v.to(device_type)
inputs = input_dict["input"]
labels = labels.to(device_type)

optional_context_parallel_ctx = (
dist_utils.create_context_parallel_ctx(
cp_mesh=parallel_dims.world_mesh["cp"],
cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
cp_seq_dims=[1, 1] + [0 for _ in model_parts],
cp_no_restore_buffers={inputs, labels},
cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
)
if parallel_dims.cp_enabled
else None
# Create device object for post_dataloading_process
device = torch.device(device_type)

# Process data (extract inputs, handle attention masks, CP sharding)
inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process(
input_dict, labels, device, model_parts
)

if parallel_dims.pp_enabled:
assert self.pp_schedule is not None
assert self.pp_has_first_stage is not None
assert self.pp_has_last_stage is not None
# Pipeline Parallel forward inside eval() call
with self.validation_context(optional_context_parallel_ctx):
with self.validation_context():
targets, losses = (
(labels, []) if self.pp_has_last_stage else (None, None)
)
if self.pp_has_first_stage:
self.pp_schedule.eval(
inputs,
**extra_inputs,
**extra_kwargs,
target=targets,
losses=losses,
)
else:
self.pp_schedule.eval(target=targets, losses=losses)
self.pp_schedule.eval(
**extra_kwargs,
target=targets,
losses=losses,
)

# accumulate losses across pipeline microbatches
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
@@ -161,10 +229,12 @@ def validate(
else torch.tensor([-1.0], device=device_type)
)
else:
with self.validation_context(optional_context_parallel_ctx):
with self.validation_context():
assert len(model_parts) == 1
with self.maybe_enable_amp:
predictions = model_parts[0](inputs)
predictions = model_parts[0](
inputs, **extra_inputs, **extra_kwargs
)
loss = self.loss_fn(predictions, labels)

accumulated_losses.append(loss.detach())
@@ -193,6 +263,7 @@ def build_validator(
dp_world_size: int,
dp_rank: int,
tokenizer: BaseTokenizer,
model_args: BaseModelArgs,
parallel_dims: ParallelDims,
loss_fn: LossFunction,
validation_context: ValidationContext,
@@ -208,6 +279,7 @@ def build_validator(
dp_world_size=dp_world_size,
dp_rank=dp_rank,
tokenizer=tokenizer,
model_args=model_args,
parallel_dims=parallel_dims,
loss_fn=loss_fn,
validation_context=validation_context,
116 changes: 116 additions & 0 deletions torchtitan/distributed/context_parallel.py
@@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.experimental._attention import (
_ContextParallel,
_enable_context_parallel_dispatcher,
)
from torch.distributed.tensor.parallel import parallelize_module

from torchtitan.distributed import utils as dist_utils
from torchtitan.tools.logging import logger


def apply_cp(
Contributor: This name sounds too generic for what it's assuming, e.g. you needed a different one for flux. Maybe apply_cp_to_transformer_blocks and send in model.layers.values()?

model: nn.Module,
cp_mesh: DeviceMesh,
use_flex_attn: bool,
) -> None:
"""
Apply context parallelism to the model.
Context Parallelism (CP) splits the sequence dimension across devices to enable
training with longer sequences. This function applies CP to the attention modules
of all transformer blocks in the model.
Args:
model: The transformer model with layers containing attention modules
cp_mesh: Device mesh for context parallel dimension
use_flex_attn: Whether the model uses FlexAttention (True) or SDPA (False)
Contributor: n00b q: we don't consider varlen + CP here? Why is that?

Note:
- For FlexAttention: CP plan uses FLEX attention type
- For SDPA: Enables CP dispatcher and uses SDPA attention type
- Applies to transformer_block.attention.inner_attention for each layer
"""
# Apply context parallelism to every transformer block
# TODO: make seq_sim configurable once the implementation doesn't assume 2
# internally.

Contributor: Suggested change: fix the typo in the TODO ("seq_sim" -> "seq_dim").

if use_flex_attn:
cp_plan = _ContextParallel(
seq_dim=2, attention_type=_ContextParallel.AttentionType.FLEX
)
else:
# This is currently required as DTensor dispatcher is not enabled to
# dispatch SDPA to CP implementation. We don't disable the CP
# dispatching in TorchTitan as it is not needed. But there is a
# corresponding API, _disable_context_parallel_dispatcher to do
# that if users have this use case.

Contributor: This comment is a little bit confusing to me - DTensor dispatcher is not enabled to dispatch SDPA to CP implementation, so we explicitly enable it by calling _enable_context_parallel_dispatcher. But why "we don't disable the CP dispatching in torchtitan"?

_enable_context_parallel_dispatcher()
cp_plan = _ContextParallel(
seq_dim=2, attention_type=_ContextParallel.AttentionType.SDPA
)

# pyrefly: ignore [not-callable]
for transformer_block in model.layers.values():
parallelize_module(
# pyrefly: ignore [missing-attribute]
module=transformer_block.attention.inner_attention,
device_mesh=cp_mesh,
parallelize_plan=cp_plan,
)

logger.info("Applied Context Parallel to the model")
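
A hedged sketch of a call site for this helper (illustrative only; assumes the caller already has a CP mesh and that the model follows the layers -> attention.inner_attention layout this function expects):

import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh

def maybe_apply_cp(model: nn.Module, cp_mesh: DeviceMesh | None, attn_type: str) -> None:
    # Only apply the CP plan when a CP mesh exists; attn_type would come from
    # the model args ("flex", "varlen", or "sdpa").
    if cp_mesh is not None:
        apply_cp(model, cp_mesh, use_flex_attn=(attn_type == "flex"))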


def prepare_context_parallel_input(
inputs: torch.Tensor,
labels: torch.Tensor,
extra_kwargs: dict[str, Any],
cp_mesh: DeviceMesh,
device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
"""
Prepare inputs, labels, and attention masks for Context Parallel forward pass.
This function prepares tensors for Context Parallel by:
1. Creating position indices based on input sequence length
2. Sharding inputs, labels, and positions across the CP mesh
3. Sharding attention masks if present
Args:
inputs: Input tensor of shape [batch_size, seq_len]
labels: Label tensor of shape [batch_size, seq_len]
extra_kwargs: Dictionary that may contain 'attention_masks' to be sharded
cp_mesh: Device mesh for context parallel dimension
device: Device to create position tensor on
Returns:
Tuple of (sharded_inputs, sharded_labels, updated_extra_kwargs) where:
- sharded_inputs: Inputs sharded along sequence dimension
- sharded_labels: Labels sharded along sequence dimension
- updated_extra_kwargs: Dict with sharded 'positions' and optionally
sharded 'attention_masks'
"""
attention_masks = extra_kwargs.get("attention_masks", None)
positions = torch.arange(
0, inputs.shape[1], dtype=torch.int32, device=device
).expand(inputs.shape)
(inputs, labels, positions), attention_masks = dist_utils.cp_shard(
cp_mesh,
(inputs, labels, positions),
attention_masks,
)
extra_kwargs["positions"] = positions
if attention_masks is not None:
extra_kwargs["attention_masks"] = attention_masks

return inputs, labels, extra_kwargs
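
To illustrate the sharding contract, a rough sketch under assumed conditions (a 2-way CP mesh already built as cp_mesh, a device variable in scope, and no attention masks; shapes are made up):

inputs = torch.zeros(4, 1024, dtype=torch.long, device=device)   # [batch, seq]
labels = torch.zeros_like(inputs)
inputs, labels, kwargs = prepare_context_parallel_input(inputs, labels, {}, cp_mesh, device)
# With 2 CP ranks, each rank now holds roughly a [4, 512] slice of inputs and
# labels, and kwargs["positions"] carries the matching position indices.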