56 changes: 45 additions & 11 deletions torchtitan/distributed/utils.py
@@ -544,22 +544,33 @@ def cp_shard(
from torch.distributed.tensor.experimental._attention import (
    _context_parallel_shard,
    _HeadTailLoadBalancer,
+    _PTRRLoadBalancer,
)
+from torch.nn.attention.flex_attention import BlockMask

INPUT_SEQ_DIM = 1
seq_len = inputs[0].size(INPUT_SEQ_DIM)
cp_world_size = cp_mesh.size(0)
-if attention_masks is not None:
-    raise ValueError(
-        "FlexAttention CP is not supported yet. Will come in the next PR."
-    )
-else:
-    # For SDPA, we use the _HeadTailLoadBalancer.
-    load_balancer = (
-        None
-        if disable_load_balancer
-        else _HeadTailLoadBalancer(seq_len, cp_world_size, cp_mesh.device_type)
-    )

+load_balancer = None
+if not disable_load_balancer:
+    if isinstance(attention_masks, BlockMask):
+        load_balancer = _PTRRLoadBalancer(attention_masks, cp_world_size)
+    elif attention_masks is None or isinstance(attention_masks, dict):
+        # For SDPA, we use the _HeadTailLoadBalancer.
+        # TODO: _PTRRLoadBalancer currently doesn't support multiple masks
+        # (dict[str, BlockMask]). Supporting that case would require it to
+        # account for how frequently each mask is used, which is not
+        # implemented yet, so we default to _HeadTailLoadBalancer.
+        load_balancer = _HeadTailLoadBalancer(
+            seq_len, cp_world_size, cp_mesh.device_type
+        )
+    else:
+        raise ValueError(
+            "cp_shard only supports attention_masks of type "
+            "None, BlockMask, or dict[str, BlockMask]."
+        )
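As background for the SDPA branch above (an illustrative aside, not part of the diff): with a causal mask, later query positions attend to more keys, so a plain contiguous split would leave the last CP rank with most of the work. The sketch below reconstructs the head-tail idea in standalone Python, under the assumption that _HeadTailLoadBalancer pairs head and tail chunks; it is not the torch.distributed implementation.

import torch

def head_tail_chunk_indices(seq_len: int, cp_world_size: int) -> list[torch.Tensor]:
    # Hypothetical helper: the token indices each CP rank would own under
    # head-tail load balancing. Assumes seq_len divides 2 * cp_world_size.
    assert seq_len % (2 * cp_world_size) == 0
    chunks = torch.arange(seq_len).chunk(2 * cp_world_size)
    # Rank i pairs head chunk i with tail chunk (2 * cp_world_size - 1 - i),
    # so cheap (early) and expensive (late) causal positions are mixed.
    return [
        torch.cat([chunks[i], chunks[2 * cp_world_size - 1 - i]])
        for i in range(cp_world_size)
    ]

# Example with seq_len=16 and 2 CP ranks:
# rank 0 holds positions [0..3, 12..15], rank 1 holds positions [4..7, 8..11].
print(head_tail_chunk_indices(16, 2))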

inputs = cast(
    tuple[torch.Tensor, ...],
@@ -571,4 +582,27 @@ def cp_shard(
    ),
)

+# BlockMask has shape [B, H, Q, KV]; we can only shard on the Q sequence
+# dimension, not KV.
+MASK_Q_SEQ_DIM = 2
+if attention_masks is not None:
+    assert isinstance(attention_masks, (BlockMask, dict))
+    masks = (
+        [attention_masks]
+        if isinstance(attention_masks, BlockMask)
+        else list(attention_masks.values())
+    )
+    masks = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=masks,
+        seq_dims=(MASK_Q_SEQ_DIM,) * len(masks),
+        load_balancer=load_balancer,
+    )
+    attention_masks = cast(
+        BlockMask | dict[str, BlockMask],
+        masks[0]
+        if isinstance(attention_masks, BlockMask)
+        else {k: v for k, v in zip(attention_masks.keys(), masks)},
+    )

return inputs, attention_masks
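A note on MASK_Q_SEQ_DIM = 2 (an illustrative aside, not part of the diff): a BlockMask is logically indexed as [B, H, Q_LEN, KV_LEN], and only the Q axis is sharded here, so each rank's local mask still spans the full KV length. A minimal sketch using the public create_block_mask API, with made-up sizes:

import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Made-up sizes for illustration: batch=2, broadcast over heads, 1024 tokens.
mask = create_block_mask(causal, B=2, H=None, Q_LEN=1024, KV_LEN=1024, device="cpu")
print(mask.shape)  # expected (2, 1, 1024, 1024); dim 2 is the Q axis sharded above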
4 changes: 3 additions & 1 deletion torchtitan/models/attention.py
@@ -14,6 +14,7 @@
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.attention.flex_attention import (
    _mask_mod_signature,
+    _score_mod_signature,
    BlockMask,
    create_block_mask,
    flex_attention,
@@ -116,7 +117,8 @@ def forward(
    k: torch.Tensor,
    v: torch.Tensor,
    *,
-    block_mask: BlockMask,
+    score_mod: _score_mod_signature | None = None,
Review comment (Contributor): arg not used anywhere
+    block_mask: BlockMask | None = None,
    scale: float | None = None,
    return_lse: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
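Regarding the review note that score_mod is not used anywhere yet: if the intent is to thread it through, a minimal standalone sketch of a forward that passes both score_mod and block_mask on to flex_attention might look as follows (a hypothetical wrapper, not torchtitan's actual module):

import torch
from torch.nn.attention.flex_attention import (
    _score_mod_signature,
    BlockMask,
    flex_attention,
)

# Hypothetical standalone forward, assuming q/k/v are already [B, H, S, D].
def flex_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    score_mod: _score_mod_signature | None = None,
    block_mask: BlockMask | None = None,
    scale: float | None = None,
    return_lse: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    # Pass score_mod and block_mask straight through to flex_attention.
    return flex_attention(
        q,
        k,
        v,
        score_mod=score_mod,
        block_mask=block_mask,
        scale=scale,
        return_lse=return_lse,
    )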
6 changes: 2 additions & 4 deletions torchtitan/models/llama3/model/args.py
@@ -56,11 +56,9 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:

if (
    job_config.parallelism.context_parallel_degree > 1
-    and self.attn_type != "sdpa"
+    and self.attn_type == "varlen"
):
-    raise NotImplementedError(
-        "CP support for FlexAttention is still in progress."
-    )
+    raise NotImplementedError("CP is not supported with varlen attention.")

def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
    return get_dense_model_nparams_and_flops(
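For completeness on the args.py change: the guard now fires only for the CP plus varlen combination, so SDPA and FlexAttention configs pass through. A tiny standalone illustration, where SimpleNamespace stands in for the real JobConfig (hypothetical, for demonstration only):

from types import SimpleNamespace

# Hypothetical stand-in for the parallelism section of JobConfig.
job_config = SimpleNamespace(parallelism=SimpleNamespace(context_parallel_degree=2))

for attn_type in ("sdpa", "varlen"):
    try:
        if job_config.parallelism.context_parallel_degree > 1 and attn_type == "varlen":
            raise NotImplementedError("CP is not supported with varlen attention.")
        print(f"{attn_type}: ok")
    except NotImplementedError as err:
        print(f"{attn_type}: {err}")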