[CP] Refactor Context Parallel to use new PyTorch CP APIs #2144
base: gh/fegin/53/base
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,7 +4,7 @@ | |
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from typing import Generator | ||
| from typing import Any, Generator | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
|
|
@@ -16,6 +16,7 @@ | |
| from torchtitan.config import JobConfig | ||
| from torchtitan.distributed import ParallelDims, utils as dist_utils | ||
| from torchtitan.hf_datasets.text_datasets import build_text_validation_dataloader | ||
| from torchtitan.protocols.model import BaseModelArgs | ||
| from torchtitan.tools import utils | ||
| from torchtitan.tools.logging import logger | ||
|
|
||
|
|
@@ -50,6 +51,7 @@ def __init__( | |
| dp_world_size: int, | ||
| dp_rank: int, | ||
| tokenizer: BaseTokenizer, | ||
| model_args: BaseModelArgs, | ||
| parallel_dims: ParallelDims, | ||
| loss_fn: LossFunction, | ||
| validation_context: Generator[None, None, None], | ||
|
|
@@ -60,6 +62,8 @@ def __init__( | |
| pp_has_last_stage: bool | None = None, | ||
| ): | ||
| self.job_config = job_config | ||
| self.tokenizer = tokenizer | ||
| self.model_args = model_args | ||
| self.parallel_dims = parallel_dims | ||
| self.loss_fn = loss_fn | ||
| self.validation_dataloader = build_text_validation_dataloader( | ||
|
|
@@ -82,6 +86,75 @@ def __init__( | |
| "unequal sample counts across ranks when dataset is exhausted." | ||
| ) | ||
|
|
||
| def post_dataloading_process( | ||
| self, | ||
| input_dict: dict[str, torch.Tensor], | ||
| labels: torch.Tensor, | ||
| device: torch.device, | ||
|
Contributor: can we just use
||
| model_parts: list[nn.Module], | ||
| ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]: | ||
| """ | ||
| Post-processing hook after data loading and before model forward pass. | ||
|
|
||
| This method processes the raw data from the dataloader and prepares it for | ||
| the model's forward pass. It separates the main input tensor from auxiliary | ||
| inputs and constructs additional keyword arguments (e.g., attention masks). | ||
|
|
||
| Args: | ||
| input_dict: Dictionary containing tensors from the dataloader. Must | ||
| contain an "input" key with the main input tensor. May contain | ||
| additional keys for auxiliary inputs (e.g., position ids). | ||
| labels: Target labels for the batch. | ||
| device: Device to use for creating new tensors (e.g., positions). | ||
| model_parts: List of model parts for accessing model methods. | ||
|
|
||
| Returns: | ||
| A tuple of (inputs, labels, extra_inputs, extra_kwargs) where: | ||
| - inputs: Main input tensor extracted from input_dict["input"]. | ||
| - labels: Target labels (potentially modified by CP sharding). | ||
| - extra_inputs: Dict of auxiliary input tensors (all keys except | ||
|
Contributor: maybe we can consolidate this into
||
| "input" from input_dict). These are passed to the model forward | ||
| but are NOT forwarded across pipeline parallel stages. | ||
| - extra_kwargs: Dict of additional keyword arguments for model forward. | ||
| These ARE forwarded across pipeline parallel stages. Contains | ||
| attention_masks if flex attention is enabled. | ||
|
|
||
| Note: | ||
| The distinction between extra_inputs and extra_kwargs is important for | ||
| pipeline parallelism: extra_kwargs are forwarded to all pipeline stages, | ||
| while extra_inputs are only available to the first stage. | ||
| """ | ||
| inputs = input_dict["input"] | ||
| extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} | ||
| # Arguments like attention_masks must go into a separate dict because | ||
| # extra_inputs are not forwarded to other stages in PP, but | ||
| # extra_kwargs are. | ||
| extra_kwargs: dict[str, Any] = {} | ||
|
|
||
| attn_type = getattr(self.model_args, "attn_type", "sdpa") | ||
|
Contributor: It seems we introduce model_args to validator because of this line. This condition is checked again in … The argument is that validator is not supposed to know ModelArgs details.
||
| if attn_type in ["flex", "varlen"]: | ||
| extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks( | ||
| input_batch=inputs, | ||
| tokenizer=self.tokenizer, | ||
| extra_inputs=extra_inputs, | ||
| ) | ||
|
|
||
| if self.parallel_dims.cp_enabled: | ||
| attention_masks = extra_kwargs.get("attention_masks", None) | ||
| positions = torch.arange( | ||
| 0, inputs.shape[1], dtype=torch.int32, device=device | ||
| ).expand(inputs.shape) | ||
| (inputs, labels, positions), attention_masks = dist_utils.cp_shard( | ||
| self.parallel_dims.world_mesh["cp"], | ||
| (inputs, labels, positions), | ||
| attention_masks, | ||
| ) | ||
| extra_kwargs["positions"] = positions | ||
| if attention_masks is not None: | ||
| extra_kwargs["attention_masks"] = attention_masks | ||
|
|
||
| return inputs, labels, extra_inputs, extra_kwargs | ||
|
|
||
| @torch.no_grad() | ||
| def validate( | ||
| self, | ||
|
|
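To make the contract above concrete, here is a minimal, self-contained sketch of the split that post_dataloading_process performs; the batch keys, vocabulary size, and shapes are illustrative assumptions, not values taken from this PR.

```python
# Sketch of the post_dataloading_process contract (illustrative keys/shapes).
import torch

# What a text dataloader batch might look like; "document_ids" is a hypothetical
# auxiliary key, only "input" is required.
input_dict = {
    "input": torch.randint(0, 32000, (2, 128)),
    "document_ids": torch.zeros(2, 128, dtype=torch.int64),
}
labels = torch.randint(0, 32000, (2, 128))

# Mirrors the split done inside post_dataloading_process:
inputs = input_dict["input"]
extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
extra_kwargs: dict = {}  # attention_masks / positions would be added here

# Consumption: extra_inputs only reach the first PP stage, while extra_kwargs
# are forwarded to every stage, e.g.
#   predictions = model_parts[0](inputs, **extra_inputs, **extra_kwargs)  # no PP
#   pp_schedule.eval(inputs, **extra_inputs, **extra_kwargs, ...)         # first stage
#   pp_schedule.eval(**extra_kwargs, target=..., losses=...)              # later stages
print(inputs.shape, list(extra_inputs), extra_kwargs)
```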
@@ -108,38 +181,39 @@ def validate( | |
| self.metrics_processor.ntokens_since_last_log += labels.numel() | ||
| for k, v in input_dict.items(): | ||
| input_dict[k] = v.to(device_type) | ||
| inputs = input_dict["input"] | ||
| labels = labels.to(device_type) | ||
|
|
||
| optional_context_parallel_ctx = ( | ||
| dist_utils.create_context_parallel_ctx( | ||
| cp_mesh=parallel_dims.world_mesh["cp"], | ||
| cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], | ||
| cp_seq_dims=[1, 1] + [0 for _ in model_parts], | ||
| cp_no_restore_buffers={inputs, labels}, | ||
| cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, | ||
| ) | ||
| if parallel_dims.cp_enabled | ||
| else None | ||
| # Create device object for post_dataloading_process | ||
| device = torch.device(device_type) | ||
|
|
||
| # Process data (extract inputs, handle attention masks, CP sharding) | ||
| inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process( | ||
| input_dict, labels, device, model_parts | ||
| ) | ||
|
|
||
| if parallel_dims.pp_enabled: | ||
| assert self.pp_schedule is not None | ||
| assert self.pp_has_first_stage is not None | ||
| assert self.pp_has_last_stage is not None | ||
| # Pipeline Parallel forward inside eval() call | ||
| with self.validation_context(optional_context_parallel_ctx): | ||
| with self.validation_context(): | ||
| targets, losses = ( | ||
| (labels, []) if self.pp_has_last_stage else (None, None) | ||
| ) | ||
| if self.pp_has_first_stage: | ||
| self.pp_schedule.eval( | ||
| inputs, | ||
| **extra_inputs, | ||
| **extra_kwargs, | ||
| target=targets, | ||
| losses=losses, | ||
| ) | ||
| else: | ||
| self.pp_schedule.eval(target=targets, losses=losses) | ||
| self.pp_schedule.eval( | ||
| **extra_kwargs, | ||
| target=targets, | ||
| losses=losses, | ||
| ) | ||
|
|
||
| # accumulate losses across pipeline microbatches | ||
| # TODO: PP+FSDP unexpectedly puts the loss back to the CPU | ||
|
|
@@ -152,10 +226,12 @@ def validate( | |
| else torch.tensor([-1.0], device=device_type) | ||
| ) | ||
| else: | ||
| with self.validation_context(optional_context_parallel_ctx): | ||
| with self.validation_context(): | ||
| assert len(model_parts) == 1 | ||
| with self.maybe_enable_amp: | ||
| predictions = model_parts[0](inputs) | ||
| predictions = model_parts[0]( | ||
| inputs, **extra_inputs, **extra_kwargs | ||
| ) | ||
| loss = self.loss_fn(predictions, labels) | ||
|
|
||
| accumulated_losses.append(loss.detach()) | ||
|
|
@@ -184,6 +260,7 @@ def build_validator( | |
| dp_world_size: int, | ||
| dp_rank: int, | ||
| tokenizer: BaseTokenizer, | ||
| model_args: BaseModelArgs, | ||
| parallel_dims: ParallelDims, | ||
| loss_fn: LossFunction, | ||
| validation_context: Generator[None, None, None], | ||
|
|
@@ -199,6 +276,7 @@ def build_validator( | |
| dp_world_size=dp_world_size, | ||
| dp_rank=dp_rank, | ||
| tokenizer=tokenizer, | ||
| model_args=model_args, | ||
| parallel_dims=parallel_dims, | ||
| loss_fn=loss_fn, | ||
| validation_context=validation_context, | ||
|
|
||
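Before the distributed utils changes below, a small illustrative sketch of the positions buffer built in the CP branch above: after cp_shard each rank holds a non-contiguous slice of the sequence, so explicit global position ids must travel with the tokens rather than being recomputed locally. Shapes here are assumptions for illustration only.

```python
# Illustrative only: the positions buffer that is CP-sharded alongside inputs/labels.
import torch

batch, seq_len = 2, 8
inputs = torch.randint(0, 32000, (batch, seq_len))
positions = torch.arange(0, seq_len, dtype=torch.int32).expand(inputs.shape)
print(positions)
# tensor([[0, 1, 2, 3, 4, 5, 6, 7],
#         [0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.int32)
```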
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ | |
|
|
||
| from torchtitan.config import Comm as CommConfig, Debug as DebugConfig, TORCH_DTYPE_MAP | ||
| from torchtitan.distributed.parallel_dims import ParallelDims | ||
| from torchtitan.protocols.model import AttentionMasksType | ||
| from torchtitan.tools.logging import logger | ||
| from torchtitan.tools.utils import device_module, device_type | ||
|
|
||
|
|
@@ -222,14 +223,11 @@ def create_context_parallel_ctx( | |
|
|
||
| def get_train_context(enable_loss_parallel: bool) -> Generator[None, None, None]: | ||
| @contextlib.contextmanager | ||
| def context(cp_context: Generator[None, None, None] | None = None): | ||
| def context(): | ||
| with contextlib.ExitStack() as stack: | ||
| if enable_loss_parallel: | ||
| stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) | ||
|
|
||
| if cp_context: | ||
| stack.enter_context(cp_context) | ||
|
|
||
| yield | ||
|
|
||
| return context | ||
|
|
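Since the CP context manager argument is gone, the returned context() now takes no arguments. A minimal usage sketch, assuming the function lives in torchtitan.distributed.utils as the dist_utils import above suggests:

```python
# Sketch: the simplified train/validation context after this change.
from torchtitan.distributed import utils as dist_utils

validation_context = dist_utils.get_train_context(enable_loss_parallel=False)

with validation_context():
    # loss_parallel() is entered only when enable_loss_parallel=True;
    # context-parallel setup now happens via cp_shard before the forward pass.
    pass
```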
@@ -515,3 +513,37 @@ def _clip_grad_norm_with_ep( | |
| torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach) | ||
|
|
||
| return total_norm | ||
|
|
||
|
|
||
| def cp_shard( | ||
|
Contributor: can this function also go to context_parallel.py? maybe we can consolidate with
||
| cp_mesh: DeviceMesh, | ||
| inputs: torch.Tensor, | ||
| attention_masks: AttentionMasksType | None, | ||
| ): | ||
| from torch.distributed.tensor.experimental._attention import ( | ||
| _context_parallel_shard, | ||
| _HeadTailLoadBalancer, | ||
| ) | ||
| from torch.nn.attention.flex_attention import BlockMask | ||
|
|
||
| INPUT_SEQ_DIM = 1 | ||
| seq_len = inputs[0].size(INPUT_SEQ_DIM) | ||
| cp_world_size = cp_mesh.size(0) | ||
| if isinstance(attention_masks, BlockMask): | ||
|
||
| raise ValueError( | ||
| "FlexAttention CP is not supported yet. Will come in the next PR." | ||
| ) | ||
| else: | ||
|
||
| # For SDPA, we use the _HeadTailLoadBalancer. | ||
| load_balancer = _HeadTailLoadBalancer( | ||
| seq_len, cp_world_size, cp_mesh.device_type | ||
| ) | ||
|
|
||
| inputs = _context_parallel_shard( | ||
| mesh=cp_mesh, | ||
| buffers=inputs, | ||
| seq_dims=tuple(1 for _ in inputs), | ||
| load_balancer=load_balancer, | ||
| ) | ||
|
|
||
| return inputs, attention_masks | ||
I feel we should consolidate this into input_dict["labels"]
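For intuition about the _HeadTailLoadBalancer used above: with causal attention a plain contiguous split leaves later ranks with far more work, so each CP rank takes one chunk from the head of the sequence and the mirrored chunk from the tail. The sketch below imitates that split locally; it does not call the private PyTorch API, and the exact chunking is an assumption about the scheme, shown only to explain why positions must be sharded together with the tokens.

```python
# Illustrative head-tail split for CP rank 0 of 4 (no distributed setup needed).
import torch

batch, seq_len, cp_world_size = 2, 16, 4
positions = torch.arange(0, seq_len, dtype=torch.int32).expand(batch, seq_len)

# Split the sequence into 2 * cp_world_size chunks; rank r keeps chunk r from
# the head and chunk (2 * cp_world_size - 1 - r) from the tail.
chunks = positions.chunk(2 * cp_world_size, dim=1)
rank = 0
local_positions = torch.cat(
    [chunks[rank], chunks[2 * cp_world_size - 1 - rank]], dim=1
)
print(local_positions[0])  # tensor([ 0,  1, 14, 15], dtype=torch.int32)
# The shard is non-contiguous in the original sequence, which is why explicit
# position ids are computed before cp_shard and passed to the model.
```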