diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index 4673807347..4b2c79e788 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -6,7 +6,7 @@ from collections.abc import Callable from contextlib import AbstractContextManager -from typing import TypeAlias +from typing import Any, TypeAlias import torch import torch.nn as nn @@ -17,14 +17,13 @@ from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.distributed.context_parallel import prepare_context_parallel_input from torchtitan.hf_datasets.text_datasets import build_text_validation_dataloader +from torchtitan.protocols.model import BaseModelArgs from torchtitan.tools import utils from torchtitan.tools.logging import logger -ValidationContext: TypeAlias = Callable[ - [AbstractContextManager[None] | None], - AbstractContextManager[None], -] +ValidationContext: TypeAlias = Callable[[], AbstractContextManager[None]] class BaseValidator: @@ -57,6 +56,7 @@ def __init__( dp_world_size: int, dp_rank: int, tokenizer: BaseTokenizer, + model_args: BaseModelArgs, parallel_dims: ParallelDims, loss_fn: LossFunction, validation_context: ValidationContext, @@ -67,6 +67,8 @@ def __init__( pp_has_last_stage: bool | None = None, ): self.job_config = job_config + self.tokenizer = tokenizer + self.model_args = model_args self.parallel_dims = parallel_dims self.loss_fn = loss_fn self.validation_dataloader = build_text_validation_dataloader( @@ -89,6 +91,71 @@ def __init__( "unequal sample counts across ranks when dataset is exhausted." ) + def post_dataloading_process( + self, + input_dict: dict[str, torch.Tensor], + labels: torch.Tensor, + device: torch.device, + model_parts: list[nn.Module], + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]: + """ + Post-processing hook after data loading and before model forward pass. + + This method processes the raw data from the dataloader and prepares it for + the model's forward pass. It separates the main input tensor from auxiliary + inputs and constructs additional keyword arguments (e.g., attention masks). + + Args: + input_dict: Dictionary containing tensors from the dataloader. Must + contain an "input" key with the main input tensor. May contain + additional keys for auxiliary inputs (e.g., position ids). + labels: Target labels for the batch. + device: Device to use for creating new tensors (e.g., positions). + model_parts: List of model parts for accessing model methods. + + Returns: + A tuple of (inputs, labels, extra_inputs, extra_kwargs) where: + - inputs: Main input tensor extracted from input_dict["input"]. + - labels: Target labels (potentially modified by CP sharding). + - extra_inputs: Dict of auxiliary input tensors (all keys except + "input" from input_dict). These are passed to the model forward + but are NOT forwarded across pipeline parallel stages. + - extra_kwargs: Dict of additional keyword arguments for model forward. + These ARE forwarded across pipeline parallel stages. Contains + attention_masks if flex attention is enabled. + + Note: + The distinction between extra_inputs and extra_kwargs is important for + pipeline parallelism: extra_kwargs are forwarded to all pipeline stages, + while extra_inputs are only available to the first stage. 
+ """ + inputs = input_dict["input"] + extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} + # For arguments, like attention_masks, we have to put them in a separate + # dict as extra_inputs are not forwarded to other stages in PP, but + # extra_kwargs are. + extra_kwargs: dict[str, Any] = {} + + attn_type = getattr(self.model_args, "attn_type", "sdpa") + if attn_type in ["flex", "varlen"]: + # pyrefly: ignore [not-callable] + extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks( + input_batch=inputs, + tokenizer=self.tokenizer, + extra_inputs=extra_inputs, + ) + + if self.parallel_dims.cp_enabled: + inputs, labels, extra_kwargs = prepare_context_parallel_input( + inputs, + labels, + extra_kwargs, + self.parallel_dims.world_mesh["cp"], + device, + ) + + return inputs, labels, extra_inputs, extra_kwargs + @torch.no_grad() # pyrefly: ignore [bad-override] def validate( @@ -117,19 +184,14 @@ def validate( self.metrics_processor.ntokens_since_last_log += labels.numel() for k, v in input_dict.items(): input_dict[k] = v.to(device_type) - inputs = input_dict["input"] labels = labels.to(device_type) - optional_context_parallel_ctx = ( - dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], - cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], - cp_seq_dims=[1, 1] + [0 for _ in model_parts], - cp_no_restore_buffers={inputs, labels}, - cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, - ) - if parallel_dims.cp_enabled - else None + # Create device object for post_dataloading_process + device = torch.device(device_type) + + # Process data (extract inputs, handle attention masks, CP sharding) + inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process( + input_dict, labels, device, model_parts ) if parallel_dims.pp_enabled: @@ -137,18 +199,24 @@ def validate( assert self.pp_has_first_stage is not None assert self.pp_has_last_stage is not None # Pipeline Parallel forward inside eval() call - with self.validation_context(optional_context_parallel_ctx): + with self.validation_context(): targets, losses = ( (labels, []) if self.pp_has_last_stage else (None, None) ) if self.pp_has_first_stage: self.pp_schedule.eval( inputs, + **extra_inputs, + **extra_kwargs, target=targets, losses=losses, ) else: - self.pp_schedule.eval(target=targets, losses=losses) + self.pp_schedule.eval( + **extra_kwargs, + target=targets, + losses=losses, + ) # accumulate losses across pipeline microbatches # TODO: PP+FSDP unexpectedly puts the loss back to the CPU @@ -161,10 +229,12 @@ def validate( else torch.tensor([-1.0], device=device_type) ) else: - with self.validation_context(optional_context_parallel_ctx): + with self.validation_context(): assert len(model_parts) == 1 with self.maybe_enable_amp: - predictions = model_parts[0](inputs) + predictions = model_parts[0]( + inputs, **extra_inputs, **extra_kwargs + ) loss = self.loss_fn(predictions, labels) accumulated_losses.append(loss.detach()) @@ -193,6 +263,7 @@ def build_validator( dp_world_size: int, dp_rank: int, tokenizer: BaseTokenizer, + model_args: BaseModelArgs, parallel_dims: ParallelDims, loss_fn: LossFunction, validation_context: ValidationContext, @@ -208,6 +279,7 @@ def build_validator( dp_world_size=dp_world_size, dp_rank=dp_rank, tokenizer=tokenizer, + model_args=model_args, parallel_dims=parallel_dims, loss_fn=loss_fn, validation_context=validation_context, diff --git a/torchtitan/distributed/context_parallel.py 
b/torchtitan/distributed/context_parallel.py new file mode 100644 index 0000000000..992fcf6c06 --- /dev/null +++ b/torchtitan/distributed/context_parallel.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.experimental._attention import ( + _ContextParallel, + _enable_context_parallel_dispatcher, +) +from torch.distributed.tensor.parallel import parallelize_module + +from torchtitan.distributed import utils as dist_utils +from torchtitan.tools.logging import logger + + +def apply_cp( + model: nn.Module, + cp_mesh: DeviceMesh, + use_flex_attn: bool, +) -> None: + """ + Apply context parallelism to the model. + + Context Parallelism (CP) splits the sequence dimension across devices to enable + training with longer sequences. This function applies CP to the attention modules + of all transformer blocks in the model. + + Args: + model: The transformer model with layers containing attention modules + cp_mesh: Device mesh for context parallel dimension + use_flex_attn: Whether the model uses FlexAttention (True) or SDPA (False) + + Note: + - For FlexAttention: CP plan uses FLEX attention type + - For SDPA: Enables CP dispatcher and uses SDPA attention type + - Applies to transformer_block.attention.inner_attention for each layer + """ + # Apply context parallelism to every transformer block + # TODO: make seq_dim configurable once the implementation doesn't assume 2 + # internally. + if use_flex_attn: + cp_plan = _ContextParallel( + seq_dim=2, attention_type=_ContextParallel.AttentionType.FLEX + ) + else: + # This is currently required because the DTensor dispatcher does not + # dispatch SDPA to the CP implementation. We don't disable CP + # dispatching in TorchTitan as it is not needed, but there is a + # corresponding API, _disable_context_parallel_dispatcher, for users + # who need it. + _enable_context_parallel_dispatcher() + cp_plan = _ContextParallel( + seq_dim=2, attention_type=_ContextParallel.AttentionType.SDPA + ) + + # pyrefly: ignore [not-callable] + for transformer_block in model.layers.values(): + parallelize_module( + # pyrefly: ignore [missing-attribute] + module=transformer_block.attention.inner_attention, + device_mesh=cp_mesh, + parallelize_plan=cp_plan, + ) + + logger.info("Applied Context Parallel to the model") + + +def prepare_context_parallel_input( + inputs: torch.Tensor, + labels: torch.Tensor, + extra_kwargs: dict[str, Any], + cp_mesh: DeviceMesh, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: + """ + Prepare inputs, labels, and attention masks for Context Parallel forward pass. + + This function prepares tensors for Context Parallel by: + 1. Creating position indices based on input sequence length + 2. Sharding inputs, labels, and positions across the CP mesh + 3.
Sharding attention masks if present + + Args: + inputs: Input tensor of shape [batch_size, seq_len] + labels: Label tensor of shape [batch_size, seq_len] + extra_kwargs: Dictionary that may contain 'attention_masks' to be sharded + cp_mesh: Device mesh for context parallel dimension + device: Device to create position tensor on + + Returns: + Tuple of (sharded_inputs, sharded_labels, updated_extra_kwargs) where: + - sharded_inputs: Inputs sharded along sequence dimension + - sharded_labels: Labels sharded along sequence dimension + - updated_extra_kwargs: Dict with sharded 'positions' and optionally + sharded 'attention_masks' + """ + attention_masks = extra_kwargs.get("attention_masks", None) + positions = torch.arange( + 0, inputs.shape[1], dtype=torch.int32, device=device + ).expand(inputs.shape) + (inputs, labels, positions), attention_masks = dist_utils.cp_shard( + cp_mesh, + (inputs, labels, positions), + attention_masks, + ) + extra_kwargs["positions"] = positions + if attention_masks is not None: + extra_kwargs["attention_masks"] = attention_masks + + return inputs, labels, extra_kwargs diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 811e062958..1dfb635078 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -10,7 +10,7 @@ from abc import abstractmethod from collections.abc import Iterable from datetime import timedelta -from typing import Protocol +from typing import cast, Protocol import torch import torch.distributed._functional_collectives as funcol @@ -23,6 +23,7 @@ from torchtitan.config import Comm as CommConfig, Debug as DebugConfig, TORCH_DTYPE_MAP from torchtitan.distributed.parallel_dims import ParallelDims +from torchtitan.protocols.model import AttentionMasksType from torchtitan.tools.logging import logger from torchtitan.tools.utils import device_module, device_type @@ -231,23 +232,17 @@ def create_context_parallel_ctx( class TrainContext(Protocol): @abstractmethod - def __call__( - self, - cp_context: contextlib.AbstractContextManager[None] | None = None, - ) -> contextlib.AbstractContextManager[None]: + def __call__(self) -> contextlib.AbstractContextManager[None]: pass def get_train_context(enable_loss_parallel: bool) -> TrainContext: @contextlib.contextmanager - def context(cp_context: contextlib.AbstractContextManager[None] | None = None): + def context(): with contextlib.ExitStack() as stack: if enable_loss_parallel: stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) - if cp_context: - stack.enter_context(cp_context) - yield return context @@ -538,3 +533,42 @@ def _clip_grad_norm_with_ep( torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach) return total_norm + + +def cp_shard( + cp_mesh: DeviceMesh, + inputs: tuple[torch.Tensor, ...], + attention_masks: AttentionMasksType | None, + disable_load_balancer: bool = False, +): + from torch.distributed.tensor.experimental._attention import ( + _context_parallel_shard, + _HeadTailLoadBalancer, + ) + + INPUT_SEQ_DIM = 1 + seq_len = inputs[0].size(INPUT_SEQ_DIM) + cp_world_size = cp_mesh.size(0) + if attention_masks is not None: + raise ValueError( + "FlexAttention CP is not supported yet. Will come in the next PR." + ) + else: + # For SDPA, we use the _HeadTailLoadBalancer. 
+ load_balancer = ( + None + if disable_load_balancer + else _HeadTailLoadBalancer(seq_len, cp_world_size, cp_mesh.device_type) + ) + + inputs = cast( + tuple[torch.Tensor, ...], + _context_parallel_shard( + mesh=cp_mesh, + buffers=inputs, + seq_dims=tuple(1 for _ in inputs), + load_balancer=load_balancer, + ), + ) + + return inputs, attention_masks diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index 66ad151dd0..55eabadd77 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -19,6 +19,7 @@ from torchtitan.components.validate import build_validator from torchtitan.config import JobConfig from torchtitan.distributed import utils as dist_utils +from torchtitan.distributed.context_parallel import prepare_context_parallel_input from torchtitan.hf_datasets.text_datasets import build_text_dataloader from torchtitan.tools import utils from torchtitan.tools.logging import logger @@ -152,42 +153,55 @@ def batch_generator( yield input_dict, labels - def forward_backward_step( + def post_dataloading_process( self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor - ) -> torch.Tensor: - model_parts = self.model_parts - parallel_dims = self.parallel_dims - + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]: inputs = input_dict["input"] - extra_kwargs = {} - - if getattr(self.model_args, "attn_type", "sdpa") == "flex": - extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks( + extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} + # For arguments, like attention_masks, we have to put them in a separate + # dict as extra_inputs are not forwarded to other stages in PP, but + # extra_kwargs are. 
+ extra_kwargs: dict[str, Any] = {} + + attn_type = getattr(self.model_args, "attn_type", "sdpa") + if attn_type in ["flex", "varlen"]: + extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks( input_batch=inputs, tokenizer=self.tokenizer, + extra_inputs=extra_inputs, ) - optional_context_parallel_ctx = ( - dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], - cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], - cp_seq_dims=[1, 1] + [0 for _ in model_parts], - cp_no_restore_buffers={inputs, labels}, - cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, + if self.parallel_dims.cp_enabled: + inputs, labels, extra_kwargs = prepare_context_parallel_input( + inputs, + labels, + extra_kwargs, + self.parallel_dims.world_mesh["cp"], + self.device, ) - if parallel_dims.cp_enabled - else None + + return inputs, labels, extra_inputs, extra_kwargs + + def forward_backward_step( + self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor + ) -> torch.Tensor: + model_parts = self.model_parts + parallel_dims = self.parallel_dims + + inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process( + input_dict, labels ) if parallel_dims.pp_enabled: # Pipeline Parallel forward / backward inside step() call - with self.train_context(optional_context_parallel_ctx): + with self.train_context(): targets, losses = ( (labels, []) if self.pp_has_last_stage else (None, None) ) if self.pp_has_first_stage: self.pp_schedule.step( inputs, + **extra_inputs, **extra_kwargs, target=targets, losses=losses, @@ -211,10 +225,10 @@ def forward_backward_step( ) else: # Non-PP forward / backward - with self.train_context(optional_context_parallel_ctx): + with self.train_context(): assert len(model_parts) == 1 with self.maybe_enable_amp: - pred = model_parts[0](inputs, **extra_kwargs) + pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs) loss = self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index 83e24d7dc1..ed81389f5b 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -90,6 +90,7 @@ def parallelize_deepseekv3( world_mesh["tp"], loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, + cp_enabled=parallel_dims.cp_enabled, ) maybe_enable_async_tp(job_config, world_mesh["tp"]) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 63fb910376..d00252671a 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -19,6 +19,7 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import NoParallel, ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.context_parallel import apply_cp from torchtitan.distributed.dual_pipe_v import get_dual_pipe_v_flag from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.models.llama3.infra.parallelize import apply_ddp @@ -89,6 +90,7 @@ def parallelize_deepseekv3( world_mesh["tp"], loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, + cp_enabled=parallel_dims.cp_enabled, ) 
maybe_enable_async_tp(job_config, world_mesh["tp"]) @@ -110,6 +112,10 @@ def parallelize_deepseekv3( dual_pipe_v=dual_pipe_v, ) + if parallel_dims.cp_enabled: + use_flex_attn = attn_type == "flex" + apply_cp(model, world_mesh["cp"], use_flex_attn) + model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) @@ -165,9 +171,6 @@ def parallelize_deepseekv3( else: logger.info("Applied FSDP to the model") - if parallel_dims.cp_enabled: - logger.info("Applied Context Parallel to the model") - if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: @@ -188,6 +191,7 @@ def apply_non_moe_tp( tp_mesh: DeviceMesh, loss_parallel: bool, enable_float8_tensorwise_tp: bool, + cp_enabled: bool, ): """Apply tensor parallelism.""" # 1. Parallelize the embedding and shard its outputs (which are the first @@ -226,15 +230,19 @@ def apply_non_moe_tp( # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + positions_sharding = Replicate() if cp_enabled else None # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), - # NOTE: when the fourth argument (positions) is not None, its input layout - # and desired input layout should be Replicate() "attention": prepare_module_input( - input_layouts=(Shard(1), Replicate(), None, None), - desired_input_layouts=(Replicate(), Replicate(), None, None), + input_layouts=(Shard(1), Replicate(), None, positions_sharding), + desired_input_layouts=( + Replicate(), + Replicate(), + None, + positions_sharding, + ), ), # NOTE: use_local_output=False make the output to be a DTensor instead of a plain Tensor # so that the intermedidate results k is generated as a DTensor and its gradient is diff --git a/torchtitan/models/flux/infra/parallelize.py b/torchtitan/models/flux/infra/parallelize.py index b27fa93a31..f34d2cddb1 100644 --- a/torchtitan/models/flux/infra/parallelize.py +++ b/torchtitan/models/flux/infra/parallelize.py @@ -14,6 +14,11 @@ from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy +from torch.distributed.tensor.experimental._attention import ( + _ContextParallel, + _enable_context_parallel_dispatcher, +) +from torch.distributed.tensor.parallel import parallelize_module from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims @@ -28,6 +33,9 @@ def parallelize_flux( if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) + if parallel_dims.cp_enabled: + apply_cp(model, parallel_dims.world_mesh["cp"]) + if parallel_dims.fsdp_enabled: if parallel_dims.dp_replicate_enabled: dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") @@ -47,16 +55,6 @@ def parallelize_flux( else: logger.info("Applied FSDP to the model") - if parallel_dims.cp_enabled: - # The attention in Flux does not use causal mask. 
- # Currently, load_balance must be disabled in order to support Context Parallelism - # in Pytorch's experimental ring attention module - # https://github.com/pytorch/pytorch/blob/v2.9.0/torch/distributed/tensor/experimental/_attention.py#L395 - from torch.distributed.tensor.experimental._attention import _cp_options - - _cp_options.enable_load_balance = False - logger.info("Applied Context Parallel to the model") - return model @@ -134,6 +132,61 @@ def apply_ac(model: nn.Module, ac_config): logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") +def apply_cp(model: nn.Module, cp_mesh: DeviceMesh) -> None: + """ + Apply context parallelism to the Flux model. + + Args: + model: The Flux model with double_blocks and single_blocks containing + inner attention modules. + cp_mesh: Device mesh for context parallel dimension + + Note: + - Uses SDPA attention type + - Applies to img_attn, txt_attn, and the block-level inner_attention of + each DoubleStreamBlock, and to inner_attention of each SingleStreamBlock + """ + # Apply context parallelism to every DoubleStreamBlock and SingleStreamBlock + # TODO: make seq_dim configurable once the implementation doesn't assume 2 + # internally. + + _enable_context_parallel_dispatcher() + cp_plan = _ContextParallel( + seq_dim=2, attention_type=_ContextParallel.AttentionType.SDPA + ) + + # pyrefly: ignore [not-iterable] + for double_block in model.double_blocks: + parallelize_module( + # pyrefly: ignore [missing-attribute] + module=double_block.img_attn.inner_attention, + device_mesh=cp_mesh, + parallelize_plan=cp_plan, + ) + parallelize_module( + # pyrefly: ignore [missing-attribute] + module=double_block.txt_attn.inner_attention, + device_mesh=cp_mesh, + parallelize_plan=cp_plan, + ) + parallelize_module( + # pyrefly: ignore [missing-attribute] + module=double_block.inner_attention, + device_mesh=cp_mesh, + parallelize_plan=cp_plan, + ) + + # pyrefly: ignore [not-iterable] + for single_block in model.single_blocks: + parallelize_module( + # pyrefly: ignore [missing-attribute] + module=single_block.inner_attention, + device_mesh=cp_mesh, + parallelize_plan=cp_plan, + ) + + logger.info("Applied Context Parallel to the Flux model") + + def parallelize_encoders( t5_model: nn.Module, clip_model: nn.Module, diff --git a/torchtitan/models/flux/model/layers.py b/torchtitan/models/flux/model/layers.py index 30ba52d3a3..6d0e696dd9 100644 --- a/torchtitan/models/flux/model/layers.py +++ b/torchtitan/models/flux/model/layers.py @@ -13,6 +13,8 @@ from einops import rearrange from torch import nn, Tensor +from torchtitan.models.attention import ScaledDotProductAttentionWrapper + def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 @@ -124,6 +126,7 @@ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.norm = QKNorm(head_dim) self.proj = nn.Linear(dim, dim) + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self): for layer in (self.qkv, self.proj): @@ -136,7 +139,7 @@ def forward(self, x: Tensor, pe: Tensor) -> Tensor: q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) q, k = self.norm(q, k, v) q, k = apply_rope(q, k, pe) - x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = self.inner_attention(q, k, v) x = rearrange(x, "B H L D -> B L (H D)") x = self.proj(x) return x @@ -206,6 +209,8 @@ def __init__( nn.Linear(mlp_hidden_dim, hidden_size, bias=True), ) + self.inner_attention = ScaledDotProductAttentionWrapper() + def init_weights(self): # initialize all the
nn.Linear submodules for layer in ( @@ -257,7 +262,7 @@ def forward( v = torch.cat((txt_v, img_v), dim=2) q, k = apply_rope(q, k, pe) - attn = torch.nn.functional.scaled_dot_product_attention(q, k, v) + attn = self.inner_attention(q, k, v) attn = rearrange(attn, "B H L D -> B L (H D)") txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] @@ -308,6 +313,7 @@ def __init__( self.mlp_act = nn.GELU(approximate="tanh") self.modulation = Modulation(hidden_size, double=False) + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self): for layer in (self.linear1, self.linear2): @@ -329,7 +335,7 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: # compute attention q, k = apply_rope(q, k, pe) - attn = torch.nn.functional.scaled_dot_product_attention(q, k, v) + attn = self.inner_attention(q, k, v) attn = rearrange(attn, "B H L D -> B L (H D)") # compute activation in mlp stream, cat again and run second linear layer diff --git a/torchtitan/models/flux/train.py b/torchtitan/models/flux/train.py index 3e008fba59..394ea958a9 100644 --- a/torchtitan/models/flux/train.py +++ b/torchtitan/models/flux/train.py @@ -136,30 +136,24 @@ def forward_backward_step( latents = pack_latents(latents) target = pack_latents(noise - labels) - optional_context_parallel_ctx = ( - dist_utils.create_context_parallel_ctx( - cp_mesh=self.parallel_dims.world_mesh["cp"], - cp_buffers=[ - latents, - latent_pos_enc, - t5_encodings, - text_pos_enc, - target, - ], - cp_seq_dims=[1, 1, 1, 1, 1], - cp_no_restore_buffers={ - latents, - latent_pos_enc, - t5_encodings, - text_pos_enc, - target, - }, - cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, + # Apply CP sharding if enabled + if self.parallel_dims.cp_enabled: + from torchtitan.distributed import utils as dist_utils + + ( + latents, + latent_pos_enc, + t5_encodings, + text_pos_enc, + target, + ), _ = dist_utils.cp_shard( + self.parallel_dims.world_mesh["cp"], + (latents, latent_pos_enc, t5_encodings, text_pos_enc, target), + None, # No attention masks for Flux + disable_load_balancer=True, ) - if self.parallel_dims.cp_enabled - else None - ) - with self.train_context(optional_context_parallel_ctx): + + with self.train_context(): with self.maybe_enable_amp: latent_noise_pred = model( img=latents, diff --git a/torchtitan/models/flux/validate.py b/torchtitan/models/flux/validate.py index 32fa7b9f55..a1aa8c8d98 100644 --- a/torchtitan/models/flux/validate.py +++ b/torchtitan/models/flux/validate.py @@ -20,6 +20,7 @@ from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.models.flux.flux_datasets import build_flux_validation_dataloader from torchtitan.models.flux.inference.sampling import generate_image, save_image +from torchtitan.models.flux.model.args import FluxModelArgs from torchtitan.models.flux.model.autoencoder import AutoEncoder from torchtitan.models.flux.model.hf_embedder import FluxEmbedder @@ -51,6 +52,7 @@ def __init__( dp_world_size: int, dp_rank: int, tokenizer: BaseTokenizer, + model_args: FluxModelArgs, parallel_dims: ParallelDims, loss_fn: LossFunction, validation_context: ValidationContext, @@ -61,6 +63,8 @@ def __init__( pp_has_last_stage: bool | None = None, ): self.job_config = job_config + self.tokenizer = tokenizer + self.model_args = model_args self.parallel_dims = parallel_dims self.loss_fn = loss_fn # pyrefly: ignore [missing-attribute] @@ -220,42 +224,33 @@ def validate( latents = pack_latents(latents) target = pack_latents(noise - labels) - 
optional_context_parallel_ctx = ( - dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], - cp_buffers=[ - latents, - latent_pos_enc, - t5_encodings, - text_pos_enc, - target, - ], - cp_seq_dims=[1, 1, 1, 1, 1], - cp_no_restore_buffers={ - latents, - latent_pos_enc, - t5_encodings, - text_pos_enc, - target, - }, - cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, - ) - if parallel_dims.cp_enabled - else None + # Apply CP sharding if enabled + if parallel_dims.cp_enabled: + ( + latents, + latent_pos_enc, + t5_encodings, + text_pos_enc, + target, + ), _ = dist_utils.cp_shard( + parallel_dims.world_mesh["cp"], + (latents, latent_pos_enc, t5_encodings, text_pos_enc, target), + None, # No attention masks for Flux + disable_load_balancer=True, ) - with self.validation_context(optional_context_parallel_ctx): - with self.maybe_enable_amp: - latent_noise_pred = model( - img=latents, - img_ids=latent_pos_enc, - txt=t5_encodings, - txt_ids=text_pos_enc, - y=clip_encodings, - timesteps=timesteps, - ) + with self.validation_context(): + with self.maybe_enable_amp: + latent_noise_pred = model( + img=latents, + img_ids=latent_pos_enc, + txt=t5_encodings, + txt_ids=text_pos_enc, + y=clip_encodings, + timesteps=timesteps, + ) - loss = self.loss_fn(latent_noise_pred, target) + loss = self.loss_fn(latent_noise_pred, target) del noise, target, latent_noise_pred, latents @@ -288,6 +283,7 @@ def build_flux_validator( dp_world_size: int, dp_rank: int, tokenizer: BaseTokenizer, + model_args: FluxModelArgs, parallel_dims: ParallelDims, loss_fn: LossFunction, validation_context: ValidationContext, @@ -303,6 +299,7 @@ def build_flux_validator( dp_world_size=dp_world_size, dp_rank=dp_rank, tokenizer=tokenizer, + model_args=model_args, parallel_dims=parallel_dims, loss_fn=loss_fn, validation_context=validation_context, diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 63bbc19ff6..46b27609e1 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -26,6 +26,7 @@ from torchtitan.config.job_config import Compile as CompileConfig from torchtitan.distributed import ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.context_parallel import apply_cp from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.tools.logging import logger @@ -89,9 +90,14 @@ def parallelize_llama( world_mesh["tp"], loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, + cp_enabled=parallel_dims.cp_enabled, ) maybe_enable_async_tp(job_config, world_mesh["tp"]) + use_flex_attn = getattr(model.model_args, "attn_type", "sdpa") == "flex" + if parallel_dims.cp_enabled: + apply_cp(model, world_mesh["cp"], use_flex_attn) + model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) @@ -132,9 +138,6 @@ def parallelize_llama( else: logger.info("Applied FSDP to the model") - if parallel_dims.cp_enabled: - logger.info("Applied Context Parallel to the model") - if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: @@ -154,6 +157,7 @@ def apply_tp( tp_mesh: DeviceMesh, loss_parallel: bool, enable_float8_tensorwise_tp: bool, + cp_enabled: bool = False, ): """Apply tensor parallelism.""" # 1. 
Parallelize the embedding and shard its outputs (which are the first @@ -208,7 +212,8 @@ def apply_tp( layer_plan = { "attention_norm": SequenceParallel(), # NOTE: when the fourth argument (positions) is not None, its input layout - # and desired input layout should be Replicate() + # and desired input layout are still None as we don't convert freqs_cis to + # a DTensor for llama3. "attention": prepare_module_input( input_layouts=(Shard(1), None, None, None), desired_input_layouts=(Replicate(), None, None, None), diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index cafd58a52e..3d7ee40721 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -494,9 +494,6 @@ def init_weights( def _precompute_freqs_cis(self) -> torch.Tensor: return precompute_freqs_cis( self.model_args.dim // self.model_args.n_heads, - # Need to compute until at least the max token limit for generation - # TODO: explain in docs/composability.md why we removed the 2x - # relaxing in our CP enablement PR self.model_args.max_seq_len, self.model_args.rope_theta, self.model_args.rope_scaling_args, diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 112153390f..9aa7c41f92 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -26,6 +26,7 @@ from torchtitan.config.job_config import Compile as CompileConfig from torchtitan.distributed import NoParallel, ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.context_parallel import apply_cp from torchtitan.distributed.dual_pipe_v import ( DualPipeExpertParallel, get_dual_pipe_v_flag, ) @@ -124,6 +125,10 @@ def parallelize_llama( dual_pipe_v=dual_pipe_v, ) + use_flex_attn = getattr(model.model_args, "attn_type", "sdpa") == "flex" + if parallel_dims.cp_enabled: + apply_cp(model, world_mesh["cp"], use_flex_attn) + model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) @@ -179,9 +184,6 @@ def parallelize_llama( else: logger.info("Applied FSDP to the model") - if parallel_dims.cp_enabled: - logger.info("Applied Context Parallel to the model") - if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: @@ -253,7 +255,8 @@ def apply_non_moe_tp( layer_plan = { "attention_norm": SequenceParallel(), # NOTE: when the fourth argument (positions) is not None, its input layout - # and desired input layout should be Replicate() + # and desired input layout are still None as we don't convert freqs_cis to + # a DTensor for llama4.
"attention": prepare_module_input( input_layouts=(Shard(1), None, None, None), desired_input_layouts=(Replicate(), None, None, None), diff --git a/torchtitan/models/llama4/model/model.py b/torchtitan/models/llama4/model/model.py index e08f733f28..5421beed83 100644 --- a/torchtitan/models/llama4/model/model.py +++ b/torchtitan/models/llama4/model/model.py @@ -532,9 +532,6 @@ def init_weights( def _precompute_freqs_cis(self) -> torch.Tensor: return precompute_freqs_cis( self.model_args.dim // self.model_args.n_heads, - # Need to compute until at least the max token limit for generation - # TODO: explain in docs/composability.md why we removed the 2x - # relaxing in our CP enablement PR self.model_args.max_seq_len, self.model_args.rope_theta, self.model_args.rope_scaling_args, diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index c2eaed8de6..10a4f3a00b 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -24,6 +24,7 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.context_parallel import apply_cp from torchtitan.distributed.dual_pipe_v import get_dual_pipe_v_flag from torchtitan.models.llama3.infra.parallelize import apply_ddp from torchtitan.models.llama4.infra.parallelize import ( @@ -97,6 +98,7 @@ def parallelize_qwen3( loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, + cp_enabled=parallel_dims.cp_enabled, ) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: @@ -117,6 +119,10 @@ def parallelize_qwen3( dual_pipe_v=dual_pipe_v, ) + if parallel_dims.cp_enabled: + use_flex_attn = attn_type == "flex" + apply_cp(model, world_mesh["cp"], use_flex_attn) + if job_config.activation_checkpoint.mode != "none": apply_ac( model, @@ -168,9 +174,6 @@ def parallelize_qwen3( else: logger.info("Applied FSDP to the model") - if parallel_dims.cp_enabled: - logger.info("Applied Context Parallel to the model") - if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: @@ -197,6 +200,7 @@ def apply_non_moe_tp( loss_parallel: bool, enable_float8_tensorwise_tp: bool, enable_async_tp: bool, + cp_enabled: bool, ): """Apply tensor parallelism.""" # 1. Parallelize the embedding and shard its outputs (which are the first @@ -246,15 +250,19 @@ def apply_non_moe_tp( # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. 
# Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + positions_sharding = Replicate() if cp_enabled else None # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), - # NOTE: when the fourth argument (positions) is not None, its input layout - # and desired input layout should be Replicate() "attention": prepare_module_input( - input_layouts=(Shard(1), Replicate(), None, None), - desired_input_layouts=(Replicate(), Replicate(), None, None), + input_layouts=(Shard(1), Replicate(), None, positions_sharding), + desired_input_layouts=( + Replicate(), + Replicate(), + None, + positions_sharding, + ), ), "attention.wq": colwise_parallel(use_local_output=False), "attention.wk": colwise_parallel(use_local_output=False), diff --git a/torchtitan/train.py b/torchtitan/train.py index 8c597cd608..c38d783fd8 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -28,6 +28,7 @@ ) from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.distributed.context_parallel import prepare_context_parallel_input from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger @@ -351,6 +352,7 @@ def __init__(self, job_config: JobConfig): dp_world_size=dp_degree, dp_rank=dp_rank, tokenizer=self.tokenizer, + model_args=self.model_args, parallel_dims=parallel_dims, loss_fn=self.loss_fn, validation_context=self.train_context, @@ -474,6 +476,15 @@ def post_dataloading_process( extra_inputs=extra_inputs, ) + if self.parallel_dims.cp_enabled: + inputs, labels, extra_kwargs = prepare_context_parallel_input( + inputs, + labels, + extra_kwargs, + self.parallel_dims.world_mesh["cp"], + self.device, + ) + return inputs, labels, extra_inputs, extra_kwargs def forward_backward_step( @@ -485,30 +496,10 @@ def forward_backward_step( inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process( input_dict, labels ) - # apply context parallelism if cp is enabled - # ensure CP handles the separate freqs_cis buffer for each pp stage - cp_buffers = [inputs, labels] - cp_seq_dims = [1, 1] - if hasattr(model_parts[0], "freqs_cis"): - cp_buffers += [m.freqs_cis for m in model_parts] - cp_seq_dims += [0 for _ in model_parts] - - optional_context_parallel_ctx = ( - dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], - # pyrefly: ignore [bad-argument-type] - cp_buffers=cp_buffers, - cp_seq_dims=cp_seq_dims, - cp_no_restore_buffers={inputs, labels}, - cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, - ) - if parallel_dims.cp_enabled - else None - ) if parallel_dims.pp_enabled: # Pipeline Parallel forward / backward inside step() call - with self.train_context(optional_context_parallel_ctx): + with self.train_context(): targets, losses = ( (labels, []) if self.pp_has_last_stage else (None, None) ) @@ -541,8 +532,8 @@ def forward_backward_step( ) else: # Non-PP forward / backward - with self.train_context(optional_context_parallel_ctx): - assert len(model_parts) == 1 + assert len(model_parts) == 1 + with self.train_context(): with self.maybe_enable_amp: pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs) loss = self.loss_fn(pred, labels)
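
Usage sketch (not part of the patch): the snippet below illustrates how the new CP entry points compose at a call site, roughly mirroring what parallelize_llama and the trainer's post_dataloading_process / forward_backward_step do after this change. It assumes a model, ParallelDims, and a batch already exist; run_parallelize and run_step are illustrative names, not torchtitan APIs.

from typing import Any

import torch
import torch.nn as nn

from torchtitan.distributed import ParallelDims
from torchtitan.distributed.context_parallel import (
    apply_cp,
    prepare_context_parallel_input,
)


def run_parallelize(
    model: nn.Module, parallel_dims: ParallelDims, use_flex_attn: bool
) -> nn.Module:
    # One-time setup: register the CP plan on each attention module via
    # parallelize_module, instead of wrapping every step in a CP context manager.
    if parallel_dims.cp_enabled:
        apply_cp(model, parallel_dims.world_mesh["cp"], use_flex_attn)
    return model


def run_step(
    model: nn.Module,
    parallel_dims: ParallelDims,
    inputs: torch.Tensor,
    labels: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    # Per batch: shard inputs/labels (plus generated positions) along the
    # sequence dimension; the returned extra_kwargs carry "positions".
    extra_kwargs: dict[str, Any] = {}
    if parallel_dims.cp_enabled:
        inputs, labels, extra_kwargs = prepare_context_parallel_input(
            inputs, labels, extra_kwargs, parallel_dims.world_mesh["cp"], device
        )
    # train_context()/validation_context() no longer take a CP context manager,
    # so the forward call itself is unchanged apart from the extra kwargs.
    pred = model(inputs, **extra_kwargs)
    return pred

The key shift shown above: CP is applied statically at parallelization time and the per-batch sequence sharding moves into post_dataloading_process, which is why train_context/validation_context lose their cp_context argument throughout the diff.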