Merged
Address comments
fegin committed Apr 2, 2025
commit 47753cb72251e6202a73ab26ffd6986ddbd8241f
13 changes: 6 additions & 7 deletions torchtitan/models/attention.py
@@ -17,8 +17,7 @@
 )
 
 
-
-class FlexAttn(torch.nn.Module):
+class FlexAttention(torch.nn.Module):
     # We registered flex_attention related attributes as class variables as we
     # need to amortize the cost of compilation.
     flex_attn: ClassVar[Callable] = torch.compile(
@@ -41,13 +40,13 @@ def __init__(self, attn_mask_type: str) -> None:
         if attn_mask_type not in ["causal", "block_causal"]:
             raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.")
         self.attn_mask_type = attn_mask_type
-        FlexAttn.used_attn_mask_types.add(attn_mask_type)
+        FlexAttention.used_attn_mask_types.add(attn_mask_type)
 
     def forward(
         self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
     ) -> torch.Tensor:
-        block_mask = FlexAttn.block_masks[self.attn_mask_type]
-        return FlexAttn.flex_attn(q, k, v, block_mask=block_mask)
+        block_mask = FlexAttention.block_masks[self.attn_mask_type]
+        return FlexAttention.flex_attn(q, k, v, block_mask=block_mask)
 
     @classmethod
     def _get_causal_mask_fn(cls) -> Callable:
@@ -112,10 +111,10 @@ def forward(
 
 def build_attention(use_flex_attn: bool, attn_mask_type: str):
     if use_flex_attn:
-        return FlexAttn(attn_mask_type)
+        return FlexAttention(attn_mask_type)
     else:
         return ScaledDotProductAttention(attn_mask_type)
 
 
 def init_attention_mask(batch: torch.Tensor, eos_id: Optional[int] = None) -> None:
-    FlexAttn.init_attention_mask(batch, eos_id)
+    FlexAttention.init_attention_mask(batch, eos_id)
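A note on the class-variable design the comment in FlexAttention refers to: torch.compile is evaluated once when the class body is executed, so every instance reuses the same compiled kernel and the compilation cost is paid a single time. Below is a minimal sketch of that pattern, using illustrative names (SharedCompiledAttention, _compiled_flex) rather than torchtitan's actual code.

import torch
from typing import Callable, ClassVar
from torch.nn.attention.flex_attention import BlockMask, flex_attention


class SharedCompiledAttention(torch.nn.Module):
    # Compiled once when the class body is executed and shared by every
    # instance, so the compilation cost is amortized across all of them.
    _compiled_flex: ClassVar[Callable] = torch.compile(flex_attention)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        block_mask: BlockMask,
    ) -> torch.Tensor:
        return SharedCompiledAttention._compiled_flex(q, k, v, block_mask=block_mask)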
2 changes: 2 additions & 0 deletions torchtitan/models/llama/model.py
@@ -471,6 +471,8 @@ def forward(self, tokens: torch.Tensor):
             torch.Tensor: Output logits after applying the Transformer model.
 
         """
+        # TODO: We will need to change the forward() signature to allow tokens
+        # to always be passed in.
         if self.model_args.use_flex_attn:
             init_attention_mask(tokens, eos_id=self.eos_id)
 
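For reference, a rough usage sketch of the renamed API from this diff (build_attention and init_attention_mask in torchtitan.models.attention). The shapes, vocabulary size, and eos_id below are illustrative assumptions, not values taken from torchtitan.

import torch

from torchtitan.models.attention import build_attention, init_attention_mask

# build_attention returns FlexAttention when use_flex_attn is True and
# ScaledDotProductAttention otherwise.
attn = build_attention(use_flex_attn=True, attn_mask_type="block_causal")

# Illustrative token batch; vocab size 32_000 and eos_id=2 are assumptions.
# This mirrors the init_attention_mask call added to Transformer.forward above,
# which builds the class-level block masks from the current batch.
tokens = torch.randint(0, 32_000, (2, 128))
init_attention_mask(tokens, eos_id=2)

# Illustrative (batch, num_heads, seq_len, head_dim) projections.
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)
out = attn(q, k, v)  # dispatches to the shared compiled flex_attention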