Update on "[CP] Enable FlexCP for llama3"
Summary:

Continuing the previous PR, this PR enables FlexAttention + CP for llama3. FlexCP will use the PTRRLoadBalancer.

Note that this PR requires pytorch/pytorch#170201
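For context, below is a minimal sketch of building a FlexAttention BlockMask and why it is sharded only on the Q dimension under CP. The causal mask_mod, the broadcast B/H, and the 8192 sequence length are illustrative and not taken from this PR:

```python
from torch.nn.attention.flex_attention import create_block_mask

# Causal mask_mod: query position q_idx may attend to key positions kv_idx <= q_idx.
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# A BlockMask logically covers [B, H, Q_LEN, KV_LEN]. With context parallelism,
# each rank keeps only its slice of the Q dimension while KV stays global, which
# is why cp_shard shards the mask on dim 2 (Q) and not dim 3 (KV).
block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=8192, KV_LEN=8192, device="cpu")
print(block_mask.shape)  # [B, H, Q_LEN, KV_LEN]; B/H broadcast to 1 when passed as None
```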

[ghstack-poisoned]
fegin committed Dec 15, 2025
commit 08039091eef2abb1219fdd3105aad87b33bda545
26 changes: 22 additions & 4 deletions torchtitan/distributed/utils.py
@@ -545,17 +545,30 @@ def cp_shard(
_HeadTailLoadBalancer,
_PTRRLoadBalancer,
)
from torch.nn.attention.flex_attention import BlockMask

INPUT_SEQ_DIM = 1
seq_len = inputs[0].size(INPUT_SEQ_DIM)
cp_world_size = cp_mesh.size(0)
if isinstance(attention_masks, BlockMask):
load_balancer = _PTRRLoadBalancer(attention_masks, cp_world_size)
else:
elif attention_masks is None or isinstance(attention_masks, dict):
# For SDPA, we use the _HeadTailLoadBalancer.
#
# TODO: For dict[str, BlockMask], _PTRRLoadBalancer currently doesn't
# support the case where there are multiple masks. To address multiple
# masks usage, _PTRRLoadBalancer also needs to take into account the
# usage frequency of each mask. So we default to _HeadTailLoadBalancer
# as we have not implemented this feature.
load_balancer = _HeadTailLoadBalancer(
seq_len, cp_world_size, cp_mesh.device_type
)
else:
raise ValueError(
"cp_shard only supports attention_masks being "
"None, a BlockMask, or a dict[str, BlockMask]."
)


inputs = cast(
tuple[torch.Tensor, ...],
@@ -567,7 +580,11 @@ def cp_shard(
),
)

# A BlockMask has shape [B, H, Q, KV], and we can only shard
# on the Q seq dimension, not KV.
MASK_Q_SEQ_DIM = 2
if attention_masks is not None:
assert isinstance(attention_masks, (BlockMask, dict))
masks = (
[attention_masks]
if isinstance(attention_masks, BlockMask)
@@ -576,13 +593,14 @@ def cp_shard(
masks = _context_parallel_shard(
mesh=cp_mesh,
buffers=masks,
seq_dims=(2,) * len(masks),
seq_dims=(MASK_Q_SEQ_DIM,) * len(masks),
load_balancer=load_balancer,
)
attention_masks = (
attention_masks = cast(
(BlockMask | dict[str, BlockMask]),
masks[0]
if isinstance(attention_masks, BlockMask)
else {k: v for k, v in zip(attention_masks.keys(), masks)}
else {k: v for k, v in zip(attention_masks.keys(), masks)},
)

return inputs, attention_masks
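For reviewers, the load-balancer selection above distills to the sketch below. It reuses the private helpers referenced in the diff (torch.distributed.tensor.experimental._attention is experimental, and these helpers come from the PyTorch change this PR depends on, so they may change); the function name _select_load_balancer is illustrative:

```python
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.experimental._attention import (
    _HeadTailLoadBalancer,
    _PTRRLoadBalancer,
)
from torch.nn.attention.flex_attention import BlockMask


def _select_load_balancer(
    cp_mesh: DeviceMesh,
    seq_len: int,
    attention_masks: BlockMask | dict[str, BlockMask] | None,
):
    cp_world_size = cp_mesh.size(0)
    if isinstance(attention_masks, BlockMask):
        # A single FlexAttention mask: PTRR load balancing can inspect the mask.
        return _PTRRLoadBalancer(attention_masks, cp_world_size)
    if attention_masks is None or isinstance(attention_masks, dict):
        # SDPA, or multiple FlexAttention masks: fall back to head-tail balancing.
        return _HeadTailLoadBalancer(seq_len, cp_world_size, cp_mesh.device_type)
    raise ValueError(
        "attention_masks must be None, a BlockMask, or a dict[str, BlockMask]"
    )
```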