[CP] Refactor Context Parallel to use new PyTorch CP APIs #2144
base: gh/fegin/53/base
Changes from all commits
851cf7a
b5eceef
00628f0
16e8836
79fb0a8
26c1b49
d4b93d0
dd1c4ff
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,7 +6,7 @@ | |
|
|
||
| from collections.abc import Callable | ||
| from contextlib import AbstractContextManager | ||
| from typing import TypeAlias | ||
| from typing import Any, TypeAlias | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
|
|
@@ -17,14 +17,13 @@ | |
| from torchtitan.components.tokenizer import BaseTokenizer | ||
| from torchtitan.config import JobConfig | ||
| from torchtitan.distributed import ParallelDims, utils as dist_utils | ||
| from torchtitan.distributed.context_parallel import prepare_context_parallel_input | ||
| from torchtitan.hf_datasets.text_datasets import build_text_validation_dataloader | ||
| from torchtitan.protocols.model import BaseModelArgs | ||
| from torchtitan.tools import utils | ||
| from torchtitan.tools.logging import logger | ||
|
|
||
| ValidationContext: TypeAlias = Callable[ | ||
| [AbstractContextManager[None] | None], | ||
| AbstractContextManager[None], | ||
| ] | ||
| ValidationContext: TypeAlias = Callable[[], AbstractContextManager[None]] | ||
|
|
||
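The refactor narrows `ValidationContext` to a zero-argument callable, since the CP context manager is no longer threaded through it. A minimal sketch of a conforming callable, using `nullcontext` purely as a stand-in for whatever context the trainer actually composes:

```python
from collections.abc import Callable
from contextlib import AbstractContextManager, nullcontext
from typing import TypeAlias

ValidationContext: TypeAlias = Callable[[], AbstractContextManager[None]]


def make_validation_context() -> AbstractContextManager[None]:
    # Stand-in only; the real trainer supplies its own context here,
    # not this nullcontext.
    return nullcontext()


validation_context: ValidationContext = make_validation_context
with validation_context():
    ...  # validation forward passes run inside this context
```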
|
|
||
| class BaseValidator: | ||
|
|
@@ -57,6 +56,7 @@ def __init__( | |
| dp_world_size: int, | ||
| dp_rank: int, | ||
| tokenizer: BaseTokenizer, | ||
| model_args: BaseModelArgs, | ||
| parallel_dims: ParallelDims, | ||
| loss_fn: LossFunction, | ||
| validation_context: ValidationContext, | ||
|
|
@@ -67,6 +67,8 @@ def __init__( | |
| pp_has_last_stage: bool | None = None, | ||
| ): | ||
| self.job_config = job_config | ||
| self.tokenizer = tokenizer | ||
| self.model_args = model_args | ||
| self.parallel_dims = parallel_dims | ||
| self.loss_fn = loss_fn | ||
| self.validation_dataloader = build_text_validation_dataloader( | ||
|
|
@@ -89,6 +91,71 @@ def __init__( | |
| "unequal sample counts across ranks when dataset is exhausted." | ||
| ) | ||
|
|
||
| def post_dataloading_process( | ||
| self, | ||
| input_dict: dict[str, torch.Tensor], | ||
| labels: torch.Tensor, | ||
| device: torch.device, | ||
|
Contributor
can we just use …
||
| model_parts: list[nn.Module], | ||
| ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]: | ||
| """ | ||
| Post-processing hook after data loading and before model forward pass. | ||
|
|
||
| This method processes the raw data from the dataloader and prepares it for | ||
| the model's forward pass. It separates the main input tensor from auxiliary | ||
| inputs and constructs additional keyword arguments (e.g., attention masks). | ||
|
|
||
| Args: | ||
| input_dict: Dictionary containing tensors from the dataloader. Must | ||
| contain an "input" key with the main input tensor. May contain | ||
| additional keys for auxiliary inputs (e.g., position ids). | ||
| labels: Target labels for the batch. | ||
| device: Device to use for creating new tensors (e.g., positions). | ||
| model_parts: List of model parts for accessing model methods. | ||
|
|
||
| Returns: | ||
| A tuple of (inputs, labels, extra_inputs, extra_kwargs) where: | ||
| - inputs: Main input tensor extracted from input_dict["input"]. | ||
| - labels: Target labels (potentially modified by CP sharding). | ||
| - extra_inputs: Dict of auxiliary input tensors (all keys except | ||
|
Contributor
maybe we can consolidate this into …
||
| "input" from input_dict). These are passed to the model forward | ||
| but are NOT forwarded across pipeline parallel stages. | ||
| - extra_kwargs: Dict of additional keyword arguments for model forward. | ||
| These ARE forwarded across pipeline parallel stages. Contains | ||
| attention_masks if flex attention is enabled. | ||
|
|
||
| Note: | ||
| The distinction between extra_inputs and extra_kwargs is important for | ||
| pipeline parallelism: extra_kwargs are forwarded to all pipeline stages, | ||
| while extra_inputs are only available to the first stage. | ||
| """ | ||
| inputs = input_dict["input"] | ||
| extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} | ||
| # For arguments, like attention_masks, we have to put them in a separate | ||
| # dict as extra_inputs are not forwarded to other stages in PP, but | ||
| # extra_kwargs are. | ||
| extra_kwargs: dict[str, Any] = {} | ||
|
|
||
| attn_type = getattr(self.model_args, "attn_type", "sdpa") | ||
|
Contributor
It seems we introduce model_args to the validator because of this line. This condition is checked again in … The argument is that the validator is not supposed to know ModelArgs details.
||
| if attn_type in ["flex", "varlen"]: | ||
| # pyrefly: ignore [not-callable] | ||
| extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks( | ||
| input_batch=inputs, | ||
| tokenizer=self.tokenizer, | ||
| extra_inputs=extra_inputs, | ||
| ) | ||
|
|
||
| if self.parallel_dims.cp_enabled: | ||
| inputs, labels, extra_kwargs = prepare_context_parallel_input( | ||
| inputs, | ||
| labels, | ||
| extra_kwargs, | ||
| self.parallel_dims.world_mesh["cp"], | ||
| device, | ||
| ) | ||
|
|
||
| return inputs, labels, extra_inputs, extra_kwargs | ||
|
|
||
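For orientation, a dependency-free toy version of the split this hook performs (no CP or attention-mask handling; `split_batch` and the `position_ids` key are illustrative names, not part of this PR):

```python
import torch


def split_batch(
    input_dict: dict[str, torch.Tensor],
) -> tuple[torch.Tensor, dict[str, torch.Tensor], dict[str, torch.Tensor]]:
    # Mirror of the hook's first step: "input" is the main tensor; everything
    # else becomes extra_inputs (first PP stage only). extra_kwargs would carry
    # attention_masks/positions and is forwarded to every PP stage.
    inputs = input_dict["input"]
    extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
    extra_kwargs: dict[str, torch.Tensor] = {}
    return inputs, extra_inputs, extra_kwargs


batch = {
    "input": torch.randint(0, 32_000, (2, 16)),
    "position_ids": torch.arange(16).repeat(2, 1),  # hypothetical auxiliary key
}
inputs, extra_inputs, extra_kwargs = split_batch(batch)
assert set(extra_inputs) == {"position_ids"} and not extra_kwargs
```

In the refactored `validate()` below, the non-PP path then consumes these values as `model_parts[0](inputs, **extra_inputs, **extra_kwargs)`.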
| @torch.no_grad() | ||
| # pyrefly: ignore [bad-override] | ||
| def validate( | ||
|
|
@@ -117,38 +184,39 @@ def validate( | |
| self.metrics_processor.ntokens_since_last_log += labels.numel() | ||
| for k, v in input_dict.items(): | ||
| input_dict[k] = v.to(device_type) | ||
| inputs = input_dict["input"] | ||
| labels = labels.to(device_type) | ||
|
|
||
| optional_context_parallel_ctx = ( | ||
| dist_utils.create_context_parallel_ctx( | ||
| cp_mesh=parallel_dims.world_mesh["cp"], | ||
| cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], | ||
| cp_seq_dims=[1, 1] + [0 for _ in model_parts], | ||
| cp_no_restore_buffers={inputs, labels}, | ||
| cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, | ||
| ) | ||
| if parallel_dims.cp_enabled | ||
| else None | ||
| # Create device object for post_dataloading_process | ||
| device = torch.device(device_type) | ||
|
|
||
| # Process data (extract inputs, handle attention masks, CP sharding) | ||
| inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process( | ||
| input_dict, labels, device, model_parts | ||
| ) | ||
|
|
||
| if parallel_dims.pp_enabled: | ||
| assert self.pp_schedule is not None | ||
| assert self.pp_has_first_stage is not None | ||
| assert self.pp_has_last_stage is not None | ||
| # Pipeline Parallel forward inside eval() call | ||
| with self.validation_context(optional_context_parallel_ctx): | ||
| with self.validation_context(): | ||
| targets, losses = ( | ||
| (labels, []) if self.pp_has_last_stage else (None, None) | ||
| ) | ||
| if self.pp_has_first_stage: | ||
| self.pp_schedule.eval( | ||
| inputs, | ||
| **extra_inputs, | ||
| **extra_kwargs, | ||
| target=targets, | ||
| losses=losses, | ||
| ) | ||
| else: | ||
| self.pp_schedule.eval(target=targets, losses=losses) | ||
| self.pp_schedule.eval( | ||
| **extra_kwargs, | ||
| target=targets, | ||
| losses=losses, | ||
| ) | ||
|
|
||
| # accumulate losses across pipeline microbatches | ||
| # TODO: PP+FSDP unexpectedly puts the loss back to the CPU | ||
|
|
@@ -161,10 +229,12 @@ def validate( | |
| else torch.tensor([-1.0], device=device_type) | ||
| ) | ||
| else: | ||
| with self.validation_context(optional_context_parallel_ctx): | ||
| with self.validation_context(): | ||
| assert len(model_parts) == 1 | ||
| with self.maybe_enable_amp: | ||
| predictions = model_parts[0](inputs) | ||
| predictions = model_parts[0]( | ||
| inputs, **extra_inputs, **extra_kwargs | ||
| ) | ||
| loss = self.loss_fn(predictions, labels) | ||
|
|
||
| accumulated_losses.append(loss.detach()) | ||
|
|
@@ -193,6 +263,7 @@ def build_validator( | |
| dp_world_size: int, | ||
| dp_rank: int, | ||
| tokenizer: BaseTokenizer, | ||
| model_args: BaseModelArgs, | ||
| parallel_dims: ParallelDims, | ||
| loss_fn: LossFunction, | ||
| validation_context: ValidationContext, | ||
|
|
@@ -208,6 +279,7 @@ def build_validator( | |
| dp_world_size=dp_world_size, | ||
| dp_rank=dp_rank, | ||
| tokenizer=tokenizer, | ||
| model_args=model_args, | ||
| parallel_dims=parallel_dims, | ||
| loss_fn=loss_fn, | ||
| validation_context=validation_context, | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,116 @@ | ||||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
| # All rights reserved. | ||||||
| # | ||||||
| # This source code is licensed under the BSD-style license found in the | ||||||
| # LICENSE file in the root directory of this source tree. | ||||||
|
|
||||||
| from typing import Any | ||||||
|
|
||||||
| import torch | ||||||
| import torch.nn as nn | ||||||
| from torch.distributed.device_mesh import DeviceMesh | ||||||
| from torch.distributed.tensor.experimental._attention import ( | ||||||
| _ContextParallel, | ||||||
| _enable_context_parallel_dispatcher, | ||||||
| ) | ||||||
| from torch.distributed.tensor.parallel import parallelize_module | ||||||
|
|
||||||
| from torchtitan.distributed import utils as dist_utils | ||||||
| from torchtitan.tools.logging import logger | ||||||
|
|
||||||
|
|
||||||
| def apply_cp( | ||||||
|
Contributor
This name sounds too generic for what it's assuming, e.g. you needed a different one for flux. Maybe …
||||||
| model: nn.Module, | ||||||
| cp_mesh: DeviceMesh, | ||||||
| use_flex_attn: bool, | ||||||
| ) -> None: | ||||||
| """ | ||||||
| Apply context parallelism to the model. | ||||||
| Context Parallelism (CP) splits the sequence dimension across devices to enable | ||||||
| training with longer sequences. This function applies CP to the attention modules | ||||||
| of all transformer blocks in the model. | ||||||
| Args: | ||||||
| model: The transformer model with layers containing attention modules | ||||||
| cp_mesh: Device mesh for context parallel dimension | ||||||
| use_flex_attn: Whether the model uses FlexAttention (True) or SDPA (False) | ||||||
|
Contributor
n00b q: We don't consider varlen + CP here? Why is that?
||||||
| Note: | ||||||
| - For FlexAttention: CP plan uses FLEX attention type | ||||||
| - For SDPA: Enables CP dispatcher and uses SDPA attention type | ||||||
| - Applies to transformer_block.attention.inner_attention for each layer | ||||||
| """ | ||||||
| # Apply context parallelism to every transformer block | ||||||
| # TODO: make seq_dim configurable once the implementation doesn't assume 2 | ||||||
|
Contributor
Suggested change …
|
||||||
| # internally. | ||||||
| if use_flex_attn: | ||||||
| cp_plan = _ContextParallel( | ||||||
| seq_dim=2, attention_type=_ContextParallel.AttentionType.FLEX | ||||||
| ) | ||||||
| else: | ||||||
| # This is currently required as DTensor dispatcher is not enabled to | ||||||
| # dispatch SDPA to CP implementation. We don't disable the CP | ||||||
| # dispatching in TorchTitan as it is not needed. But there is a | ||||||
|
Contributor
This comment is a little bit confusing to me - DTensor dispatcher is not enabled to dispatch SDPA to CP implementation, so we explicitly enable it by calling …
||||||
| # corresponding API, _disable_context_parallel_dispatcher to do | ||||||
| # that if users have this use case. | ||||||
| _enable_context_parallel_dispatcher() | ||||||
| cp_plan = _ContextParallel( | ||||||
| seq_dim=2, attention_type=_ContextParallel.AttentionType.SDPA | ||||||
| ) | ||||||
|
|
||||||
| # pyrefly: ignore [not-callable] | ||||||
| for transformer_block in model.layers.values(): | ||||||
| parallelize_module( | ||||||
| # pyrefly: ignore [missing-attribute] | ||||||
| module=transformer_block.attention.inner_attention, | ||||||
| device_mesh=cp_mesh, | ||||||
| parallelize_plan=cp_plan, | ||||||
| ) | ||||||
|
|
||||||
| logger.info("Applied Context Parallel to the model") | ||||||
|
|
||||||
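A sketch of how a model's parallelization code might invoke this helper, assuming the `ParallelDims` attributes (`cp_enabled`, `world_mesh`) used in the validator diff above; `maybe_apply_cp` is a hypothetical wrapper, not part of this PR:

```python
import torch.nn as nn

from torchtitan.distributed import ParallelDims
from torchtitan.distributed.context_parallel import apply_cp


def maybe_apply_cp(
    model: nn.Module, parallel_dims: ParallelDims, use_flex_attn: bool
) -> None:
    # Hypothetical wrapper: shard attention over the CP submesh only when
    # context parallelism is actually enabled for this job.
    if parallel_dims.cp_enabled:
        apply_cp(model, parallel_dims.world_mesh["cp"], use_flex_attn=use_flex_attn)
```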
|
|
||||||
| def prepare_context_parallel_input( | ||||||
| inputs: torch.Tensor, | ||||||
| labels: torch.Tensor, | ||||||
| extra_kwargs: dict[str, Any], | ||||||
| cp_mesh: DeviceMesh, | ||||||
| device: torch.device, | ||||||
| ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: | ||||||
| """ | ||||||
| Prepare inputs, labels, and attention masks for Context Parallel forward pass. | ||||||
| This function prepares tensors for Context Parallel by: | ||||||
| 1. Creating position indices based on input sequence length | ||||||
| 2. Sharding inputs, labels, and positions across the CP mesh | ||||||
| 3. Sharding attention masks if present | ||||||
| Args: | ||||||
| inputs: Input tensor of shape [batch_size, seq_len] | ||||||
| labels: Label tensor of shape [batch_size, seq_len] | ||||||
| extra_kwargs: Dictionary that may contain 'attention_masks' to be sharded | ||||||
| cp_mesh: Device mesh for context parallel dimension | ||||||
| device: Device to create position tensor on | ||||||
| Returns: | ||||||
| Tuple of (sharded_inputs, sharded_labels, updated_extra_kwargs) where: | ||||||
| - sharded_inputs: Inputs sharded along sequence dimension | ||||||
| - sharded_labels: Labels sharded along sequence dimension | ||||||
| - updated_extra_kwargs: Dict with sharded 'positions' and optionally | ||||||
| sharded 'attention_masks' | ||||||
| """ | ||||||
| attention_masks = extra_kwargs.get("attention_masks", None) | ||||||
| positions = torch.arange( | ||||||
| 0, inputs.shape[1], dtype=torch.int32, device=device | ||||||
| ).expand(inputs.shape) | ||||||
| (inputs, labels, positions), attention_masks = dist_utils.cp_shard( | ||||||
| cp_mesh, | ||||||
| (inputs, labels, positions), | ||||||
| attention_masks, | ||||||
| ) | ||||||
| extra_kwargs["positions"] = positions | ||||||
| if attention_masks is not None: | ||||||
| extra_kwargs["attention_masks"] = attention_masks | ||||||
|
|
||||||
| return inputs, labels, extra_kwargs | ||||||
I feel we should consolidate this into input_dict["labels"]
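For intuition about the sharding this function sets up, a toy shape walk-through with assumed sizes (CP degree 2, sequence length 8). The contiguous slice below only emulates one rank's local view, since `dist_utils.cp_shard` may use a load-balanced layout internally:

```python
import torch

batch_size, seq_len, cp_degree = 2, 8, 2  # assumed example sizes

inputs = torch.randint(0, 32_000, (batch_size, seq_len))
labels = torch.randint(0, 32_000, (batch_size, seq_len))
# Same construction as in prepare_context_parallel_input: int32 positions
# expanded to the input shape before sharding.
positions = torch.arange(0, seq_len, dtype=torch.int32).expand(inputs.shape)

# cp_shard splits the sequence dimension (dim 1) across CP ranks; emulate the
# first rank's local shard with a plain slice, for illustration only.
local_seq = seq_len // cp_degree
local_inputs = inputs[:, :local_seq]
local_labels = labels[:, :local_seq]
local_positions = positions[:, :local_seq]
print(local_inputs.shape, local_positions.shape)  # torch.Size([2, 4]) for each
```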