Update
[ghstack-poisoned]
fegin committed Dec 11, 2025
commit 4c06750c96726ad8d91a25cfaf257149c1bd41fe
23 changes: 20 additions & 3 deletions torchtitan/distributed/utils.py
@@ -523,16 +523,15 @@ def cp_shard(
     from torch.distributed.tensor.experimental._attention import (
         _context_parallel_shard,
         _HeadTailLoadBalancer,
+        _PTRRLoadBalancer,
     )
     from torch.nn.attention.flex_attention import BlockMask
 
     INPUT_SEQ_DIM = 1
     seq_len = inputs[0].size(INPUT_SEQ_DIM)
     cp_world_size = cp_mesh.size(0)
     if isinstance(attention_masks, BlockMask):
-        raise ValueError(
-            "FlexAttention CP is not supported yet. Will come in the next PR."
-        )
+        load_balancer = _PTRRLoadBalancer(attention_masks, cp_world_size)
     else:
         # For SDPA, we use the _HeadTailLoadBalancer.
         load_balancer = _HeadTailLoadBalancer(
@@ -546,4 +545,22 @@ def cp_shard(
         load_balancer=load_balancer,
     )
 
+    if attention_masks is not None:
+        masks = (
+            [attention_masks]
+            if isinstance(attention_masks, BlockMask)
+            else list(attention_masks.values())
+        )
+        masks = _context_parallel_shard(
+            mesh=cp_mesh,
+            buffers=masks,
+            seq_dims=(2,) * len(masks),
+            load_balancer=load_balancer,
+        )
+        attention_masks = (
+            masks[0]
+            if isinstance(attention_masks, BlockMask)
+            else {k: v for k, v in zip(attention_masks.keys(), masks)}
+        )
+
     return inputs, attention_masks
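
For reference, the new branch above normalizes a single BlockMask or a dict of BlockMasks into a flat list, shards each mask along its sequence dimension (dim 2) with the same load balancer that sharded the inputs, and then restores the original container. A minimal standalone sketch of that round trip, reusing the experimental _context_parallel_shard API from the hunk above (the helper name shard_attention_masks is hypothetical; in cp_shard this logic runs inline and only when attention_masks is not None):

from torch.distributed.tensor.experimental._attention import _context_parallel_shard
from torch.nn.attention.flex_attention import BlockMask


def shard_attention_masks(attention_masks, cp_mesh, load_balancer):
    # Normalize: a bare BlockMask becomes a one-element list; a dict of
    # per-module BlockMasks becomes the list of its values.
    is_single = isinstance(attention_masks, BlockMask)
    masks = [attention_masks] if is_single else list(attention_masks.values())

    # Shard every mask along its sequence dimension (dim 2), reusing the same
    # load balancer that sharded the inputs so the two layouts stay consistent.
    sharded = _context_parallel_shard(
        mesh=cp_mesh,
        buffers=masks,
        seq_dims=(2,) * len(masks),
        load_balancer=load_balancer,
    )

    # Restore the original container: a single mask comes back bare, a dict
    # comes back keyed as before (dict order keeps keys() and sharded aligned).
    return sharded[0] if is_single else dict(zip(attention_masks.keys(), sharded))
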
4 changes: 3 additions & 1 deletion torchtitan/models/attention.py
@@ -15,6 +15,7 @@
 from torch.nn.attention import sdpa_kernel, SDPBackend
 from torch.nn.attention.flex_attention import (
     _mask_mod_signature,
+    _score_mod_signature,
     BlockMask,
     create_block_mask,
     flex_attention,
@@ -113,7 +114,8 @@ def forward(
         k: torch.Tensor,
         v: torch.Tensor,
         *,
-        block_mask: BlockMask,
+        score_mod: _score_mod_signature | None = None,
Contributor: arg not used anywhere

Contributor Author: We need score_mod to ensure the function signature is the same as flex_attention so that the module hooks know how to allgather K, V.

+        block_mask: BlockMask | None = None,
         scale: float | None = None,
         return_lse: bool = False,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
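
To illustrate the point made in the review thread above: keeping the wrapper's keyword-only arguments aligned with flex_attention's calling convention lets the context-parallel module hooks treat the wrapper and flex_attention uniformly when they all-gather K and V. Below is a rough sketch of such an aligned signature; the class name _FlexAttentionLike is illustrative, not torchtitan's actual wrapper, and it simply forwards to flex_attention.

import torch
from torch.nn.attention.flex_attention import (
    _score_mod_signature,
    BlockMask,
    flex_attention,
)


class _FlexAttentionLike(torch.nn.Module):
    """Illustrative wrapper whose forward signature mirrors flex_attention."""

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        *,
        # Accepted for signature parity with flex_attention (see the review
        # thread above); callers may never actually pass it.
        score_mod: _score_mod_signature | None = None,
        block_mask: BlockMask | None = None,
        scale: float | None = None,
        return_lse: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        # Forwarding everything keeps behavior identical to calling
        # flex_attention directly with the same keyword arguments.
        return flex_attention(
            q,
            k,
            v,
            score_mod=score_mod,
            block_mask=block_mask,
            scale=scale,
            return_lse=return_lse,
        )

Whether or not score_mod is forwarded, accepting it keeps the call convention identical to flex_attention, which appears to be what the CP hooks rely on per the author's reply.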
6 changes: 2 additions & 4 deletions torchtitan/models/llama3/model/args.py
@@ -56,11 +56,9 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
 
         if (
             job_config.parallelism.context_parallel_degree > 1
-            and self.attn_type != "sdpa"
+            and self.attn_type == "varlen"
         ):
-            raise NotImplementedError(
-                "CP support for FlexAttention is still in progress."
-            )
+            raise NotImplementedError("CP support for varlen is not implemented yet.")
 
     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int