66 changes: 63 additions & 3 deletions torchtitan/distributed/utils.py
@@ -19,6 +19,7 @@

from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.protocols.model import AttentionMasksType
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import device_module, device_type

@@ -200,9 +201,6 @@ def context(cp_context: Generator[None, None, None] | None = None):
if enable_loss_parallel:
stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())

if cp_context:
stack.enter_context(cp_context)

yield

return context
@@ -443,3 +441,65 @@ def _clip_grad_norm_with_ep(
torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach)

return total_norm


def cp_shard(
cp_mesh: DeviceMesh,
inputs: torch.Tensor,
labels: torch.Tensor,
attention_masks: AttentionMasksType | None,
order_sensitive_buffers: dict[str, torch.Tensor],
order_sensitive_buffers_seq_dims: dict[str, int],
):
from torch.distributed.tensor.experimental._attention import _context_parallel_shard
from torch.distributed.tensor.experimental._load_balancer import (
_HeadTailLoadBalancer,
_PTRRLoadBalancer,
)
from torch.nn.attention.flex_attention import BlockMask

seq_len = inputs.size(1)
cp_world_size = cp_mesh.size(0)
if isinstance(attention_masks, BlockMask):
load_balancer = _PTRRLoadBalancer(attention_masks, cp_world_size)
else:
# For multiple BlockMasks or SDPA, we use the _HeadTailLoadBalancer.
load_balancer = _HeadTailLoadBalancer(
seq_len, cp_world_size, cp_mesh.device_type
)

inputs, labels = _context_parallel_shard(
mesh=cp_mesh,
buffers=(inputs, labels),
seq_dims=(1, 1),
load_balancer=load_balancer,
)

order_sensitive_buffers = _context_parallel_shard(
mesh=cp_mesh,
buffers=order_sensitive_buffers,
seq_dims=order_sensitive_buffers_seq_dims,
load_balancer=load_balancer,
)

if attention_masks is None:
return inputs, labels, None, order_sensitive_buffers

masks = (
[attention_masks]
if isinstance(attention_masks, BlockMask)
else list(attention_masks.values())
)
masks = _context_parallel_shard(
mesh=cp_mesh,
buffers=masks,
seq_dims=(2,) * len(masks),
load_balancer=load_balancer,
)
attention_masks = (
masks[0]
if isinstance(attention_masks, BlockMask)
else {k: v for k, v in zip(attention_masks.keys(), masks)}
)

return inputs, labels, attention_masks, order_sensitive_buffers
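
For orientation, here is a minimal sketch of how the new `cp_shard` helper is meant to be driven from a training step (it mirrors the `train.py` changes further down; the wrapper function and variable names are illustrative, not part of the PR):

import torch
from torch.distributed.device_mesh import DeviceMesh

from torchtitan.distributed.utils import cp_shard


def shard_batch_for_cp(model, cp_mesh: DeviceMesh, inputs: torch.Tensor,
                       labels: torch.Tensor, attention_masks=None):
    # Illustrative helper: ask the model for buffers whose values depend on
    # sequence position (e.g. freqs_cis), then shard everything for CP.
    buffers, seq_dims = model.get_order_sensitive_buffers(inputs.size(0), inputs.size(1))
    inputs, labels, attention_masks, buffers = cp_shard(
        cp_mesh,          # DeviceMesh of the context-parallel dimension
        inputs,           # (batch, seq_len) token ids, sharded along dim 1
        labels,           # (batch, seq_len) targets, sharded along dim 1
        attention_masks,  # BlockMask, dict of BlockMasks, or None for SDPA
        buffers,          # e.g. {"freqs_cis": ...}
        seq_dims,         # e.g. {"freqs_cis": 1}
    )
    return inputs, labels, attention_masks, buffers
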
35 changes: 24 additions & 11 deletions torchtitan/models/attention.py
@@ -32,6 +32,22 @@
]


class FlexAttentionKernel(torch.nn.Module):
"""Wrapper to enable FlexCP"""

_compiled_flex_attn: ClassVar[Callable] = torch.compile(
flex_attention, mode="max-autotune-no-cudagraphs"
)

def forward(self, *args, **kwargs):
# 1. _compiled_flex_attn has to be a class variable; otherwise every
# instance would hold its own compiled flex_attention, which can be slow.
# 2. Calling `self._compiled_flex_attn` is not correct: `self` would be
# passed in as the first argument and cause an error.
# `FlexAttentionKernel._compiled_flex_attn` is the correct access.
return FlexAttentionKernel._compiled_flex_attn(*args, **kwargs)
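
Note: the pitfall described in the comment above is ordinary Python attribute binding rather than anything torch-specific; a tiny standalone illustration with a plain function in place of the compiled kernel:

def add(x, y):
    return x + y


class Holder:
    _fn = add  # plain function stored as a class attribute

    def good(self, x, y):
        return Holder._fn(x, y)  # calls add(x, y)

    def bad(self, x, y):
        # Accessing through `self` binds the function as a method, so the
        # call becomes add(self, x, y) and fails with a TypeError.
        return self._fn(x, y)


h = Holder()
print(h.good(1, 2))  # 3
# h.bad(1, 2)  # TypeError: add() takes 2 positional arguments but 3 were given
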


class FlexAttentionWrapper(torch.nn.Module):
"""Wrapper around `flex_attention` to make it torch.compile and CP compatible.

@@ -45,9 +61,11 @@ class FlexAttentionWrapper(torch.nn.Module):
block_mask as a keyword argument to be compatible with _ContextParallel.
"""

_compiled_flex_attn: ClassVar[Callable] = torch.compile(
flex_attention, mode="max-autotune-no-cudagraphs"
)
def __init__(self) -> None:
super().__init__()
# TODO: remove this wrapper once FlexAttentionWrapper.forward() has the
# same signature as flex_attention() and is compatible with _ContextParallel.
self._flex_attention_kernel = FlexAttentionKernel()

def forward(
self,
@@ -59,15 +77,10 @@ def forward(
scale: float | None = None,
return_lse: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# 1. _compiled_flex_attn has to be a class variable, otherwise there will
# be multiple compiled flex_attention instances, which can be slow.
# 2. `self._compiled_flex_attn` is not correct, `self` will be passed in
# as the first argument, which will cause an error.
# `FlexAttentionWrapper._compiled_flex_attn` is correct.
# 3. Used `return_lse` instead of `return_aux` because of easier TP module notation
# to convert `lse` to be DTensor.
# Use `return_lse` instead of `return_aux` because it makes the TP module
# notation for converting `lse` to a DTensor easier.

return FlexAttentionWrapper._compiled_flex_attn(
return self._flex_attention_kernel(
q,
k,
v,
55 changes: 48 additions & 7 deletions torchtitan/models/llama3/infra/parallelize.py
@@ -67,10 +67,6 @@ def parallelize_llama(
({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
"""

use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
raise NotImplementedError("CP support for FlexAttention is still in progress.")

if parallel_dims.tp_enabled:
enable_float8_linear = "float8" in job_config.model.converters
float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in (
@@ -91,6 +87,11 @@
)
maybe_enable_async_tp(job_config, world_mesh["tp"])

use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
if parallel_dims.cp_enabled:
logger.info("Applied Context Parallel to the model")
apply_cp(model, world_mesh["cp"], use_flex_attn)

model_compile_enabled = (
job_config.compile.enable and "model" in job_config.compile.components
)
@@ -131,9 +132,6 @@
else:
logger.info("Applied FSDP to the model")

if parallel_dims.cp_enabled:
logger.info("Applied Context Parallel to the model")

if job_config.training.enable_cpu_offload:
logger.info("Applied CPU Offloading to the model")
elif parallel_dims.dp_replicate_enabled:
@@ -328,3 +326,46 @@ def apply_ddp(
replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

logger.info("Applied DDP to the model")


def apply_cp(
model: nn.Module,
cp_mesh: DeviceMesh,
use_flex_attn: bool,
) -> None:
"""
Apply context parallelism to the model.
"""
from torch.distributed.tensor.experimental._attention import (
_ContextParallel,
_enable_context_parallel_dispatcher,
)

# Apply context parallelism to every transformer block
# TODO: make seq_dim configurable once the implementation doesn't assume 2
# internally.
if use_flex_attn:
cp_plan = _ContextParallel(
seq_dim=2, attention_type=_ContextParallel.AttentionType.FLEX
)
else:
# This call is currently required because the DTensor dispatcher does not
# dispatch SDPA to the CP implementation on its own. TorchTitan never turns
# CP dispatching back off since there is no need to, but a corresponding
# API, _disable_context_parallel_dispatcher, exists for users who do need
# that.
_enable_context_parallel_dispatcher()
cp_plan = _ContextParallel(
seq_dim=2, attention_type=_ContextParallel.AttentionType.SDPA
)

for transformer_block in model.layers.values():
module = transformer_block.attention.inner_attention
if use_flex_attn:
module = module._flex_attention_kernel

parallelize_module(
module=module,
device_mesh=cp_mesh,
parallelize_plan=cp_plan,
)
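
The `seq_dim=2` above reflects the layout the attention op sees: q/k/v arrive as (batch, n_heads, seq_len, head_dim), so the sequence dimension is 2 inside attention, whereas `cp_shard` splits the token-level tensors on dim 1. A rough sketch of which dimension shrinks under a CP degree of 4 (shapes are illustrative, and the real sharding is reordered by a load balancer rather than taken as contiguous chunks):

import torch

bsz, n_heads, seq_len, head_dim, cp_degree = 2, 8, 4096, 64, 4

q = torch.randn(bsz, n_heads, seq_len, head_dim)   # per-head attention input
tokens = torch.randint(0, 32_000, (bsz, seq_len))  # token ids

# Each CP rank ends up with 1/cp_degree of the sequence dimension:
q_local = q.chunk(cp_degree, dim=2)[0]             # seq_dim=2 for attention
tokens_local = tokens.chunk(cp_degree, dim=1)[0]   # seq dim 1 for inputs/labels

print(q_local.shape)       # torch.Size([2, 8, 1024, 64])
print(tokens_local.shape)  # torch.Size([2, 1024])
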
5 changes: 0 additions & 5 deletions torchtitan/models/llama3/model/args.py
@@ -55,11 +55,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
)
self.max_seq_len = seq_len

if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
raise NotImplementedError(
"CP support for FlexAttention is still in progress."
)

def get_nparams_and_flops(
self, model: nn.Module, seq_len: int
) -> tuple[int, float]:
20 changes: 14 additions & 6 deletions torchtitan/models/llama3/model/model.py
@@ -92,8 +92,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.

The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
and the first seqlen elements will be sliced, but dim must match x.
The input freqs_cis tensor is assumed to be of shape (batch_size, seqlen, dim).

Args:
freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
@@ -104,10 +103,10 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
ndim = x.ndim
assert ndim > 1
batch_size = x.shape[0]
seqlen = x.shape[1]
freqs_cis = freqs_cis[0:seqlen]
assert freqs_cis.shape == (seqlen, x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
assert freqs_cis.shape == (batch_size, seqlen, x.shape[-1])
shape = [d if i in (0, 1, ndim - 1) else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)


@@ -474,9 +473,18 @@ def get_attention_masks(
and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
)

def get_order_sensitive_buffers(
self,
batch_size: int,
seq_len: int,
) -> tuple[dict[str, torch.Tensor], dict[str, int]]:
freqs_cis = self.freqs_cis[:seq_len].repeat(batch_size, 1, 1)
return ({"freqs_cis": freqs_cis}, {"freqs_cis": 1})

def forward(
self,
tokens: torch.Tensor,
freqs_cis: torch.Tensor,
attention_masks: AttentionMasksType | None = None,
):
"""
@@ -496,7 +504,7 @@ def forward(
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

for layer in self.layers.values():
h = layer(h, self.freqs_cis, attention_masks=attention_masks)
h = layer(h, freqs_cis, attention_masks=attention_masks)

h = self.norm(h) if self.norm else h
output = self.output(h) if self.output else h
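
A short shape walk-through of the new `freqs_cis` flow, assuming the usual llama3 convention that the rotary table is complex-valued with last dimension head_dim // 2 (the numbers below are illustrative):

import torch

bsz, seq_len, n_heads, head_dim, max_seq_len = 2, 16, 4, 8, 32

# Precomputed rotary table, shape (max_seq_len, head_dim // 2).
table = torch.polar(torch.ones(max_seq_len, head_dim // 2),
                    torch.zeros(max_seq_len, head_dim // 2))

# get_order_sensitive_buffers materializes a per-batch copy so CP can shard dim 1.
freqs_cis = table[:seq_len].repeat(bsz, 1, 1)  # (bsz, seq_len, head_dim // 2)

# Inside attention, xq viewed as complex is (bsz, seq_len, n_heads, head_dim // 2);
# the updated reshape_for_broadcast expands freqs_cis to broadcast over the head dim.
x_shape = (bsz, seq_len, n_heads, head_dim // 2)
shape = [d if i in (0, 1, len(x_shape) - 1) else 1 for i, d in enumerate(x_shape)]
print(freqs_cis.view(*shape).shape)  # torch.Size([2, 16, 1, 4])
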
7 changes: 7 additions & 0 deletions torchtitan/protocols/model.py
@@ -70,3 +70,10 @@ def get_attention_masks(
raise NotImplementedError(
"This model does not support attention masking/Flex Attention."
)

def get_order_sensitive_buffers(
self,
batch_size: int,
seq_len: int,
) -> tuple[dict[str, torch.Tensor], dict[str, int]]:
return ({}, {})
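
The contract is that the two returned dicts share their keys: the first maps each buffer name to its tensor, the second maps the same name to the sequence dimension `cp_shard` should split. A hypothetical override for a model that feeds explicit position ids (purely illustrative, not part of this PR):

import torch


class MyModel:  # hypothetical model implementing the protocol
    def get_order_sensitive_buffers(
        self, batch_size: int, seq_len: int
    ) -> tuple[dict[str, torch.Tensor], dict[str, int]]:
        pos_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)  # (B, S)
        return ({"position_ids": pos_ids}, {"position_ids": 1})
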
63 changes: 41 additions & 22 deletions torchtitan/train.py
@@ -409,12 +409,10 @@ def batch_generator(

yield input_dict, labels

def forward_backward_step(
def post_dataloader_step(
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
) -> torch.Tensor:
model_parts = self.model_parts
parallel_dims = self.parallel_dims

) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any], dict[str, Any]]:
"""Post-process the batch and labels after they are loaded from the dataloader."""
inputs = input_dict["input"]
extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
# For arguments, like attention_masks, we have to put them in a separate
@@ -423,32 +421,53 @@
extra_kwargs = {}

if getattr(self.model_args, "use_flex_attn", False):
extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
input_batch=inputs,
tokenizer=self.tokenizer,
extra_inputs=extra_inputs,
)
else:
extra_kwargs["attention_masks"] = None

# apply context parallelism if cp is enabled
# ensure CP handles the separate freqs_cis buffer for each pp stage
optional_context_parallel_ctx = (
dist_utils.create_context_parallel_ctx(
cp_mesh=parallel_dims.world_mesh["cp"],
cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
cp_seq_dims=[1, 1] + [0 for _ in model_parts],
cp_no_restore_buffers={inputs, labels},
cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
)
if parallel_dims.cp_enabled
# Get the order sensitive buffers
order_sensitive_buffers = self.model_parts[0].get_order_sensitive_buffers(
inputs.size(0), inputs.size(1)
)
cp_mesh = (
self.parallel_dims.world_mesh["cp"]
if self.parallel_dims.cp_enabled
else None
)
if cp_mesh:
(
inputs,
labels,
extra_kwargs["attention_masks"],
*order_sensitive_buffers,
) = dist_utils.cp_shard(
cp_mesh,
inputs,
labels,
extra_kwargs["attention_masks"],
*order_sensitive_buffers,
)
extra_kwargs.update(order_sensitive_buffers[0])
return inputs, labels, extra_inputs, extra_kwargs

def forward_backward_step(
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
) -> torch.Tensor:
model_parts = self.model_parts
parallel_dims = self.parallel_dims

inputs, labels, extra_inputs, extra_kwargs = self.post_dataloader_step(
input_dict, labels
)

if parallel_dims.pp_enabled:
# Pipeline Parallel forward / backward inside step() call
with self.train_context(optional_context_parallel_ctx):
targets, losses = (
(labels, []) if self.pp_has_last_stage else (None, None)
)
targets, losses = (labels, []) if self.pp_has_last_stage else (None, None)
with self.train_context():
if self.pp_has_first_stage:
self.pp_schedule.step(
inputs,
@@ -478,7 +497,7 @@ def forward_backward_step(
)
else:
# Non-PP forward / backward
with self.train_context(optional_context_parallel_ctx):
with self.train_context():
assert len(model_parts) == 1
with self.maybe_enable_amp:
pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)