Allow mixed usage of block masks
fegin committed Apr 2, 2025
commit 4d6393991ef4c9aa9fb8fa5545d23356d44fbcbb
1 change: 1 addition & 0 deletions tests/integration_tests.py
@@ -302,6 +302,7 @@ def build_test_list():
"--parallelism.data_parallel_shard_degree=4",
"--activation_checkpoint.mode='full'",
"--model.use_flex_attn",
"--model.attn_mask_type='block_causal'",
]
],
"FSDP+FLEX_ATTN",
7 changes: 5 additions & 2 deletions torchtitan/config_manager.py
@@ -196,10 +196,13 @@ def __init__(self):
self.parser.add_argument(
"--model.use_flex_attn",
action="store_true",
help="Whether to use Flex Attention.",
help="""
Whether to use Flex Attention.
Mixed usage of SDPA and FlexAttention is not supported yet.
""",
)
self.parser.add_argument(
"--model.attn_bias_type",
"--model.attn_mask_type",
type=str,
default="causal",
choices=["causal", "block_causal"],
131 changes: 63 additions & 68 deletions torchtitan/models/attention.py
@@ -17,32 +17,39 @@
)


BatchBlockMaskType = tuple[Optional[int], BlockMask]


class FlexAttn(torch.nn.Module):
# We register flex_attention-related attributes as class variables as we
# need to amortize the cost of compilation. Enabling per-instance flex_attention
# is not supported.
block_mask: ClassVar[Optional[BlockMask]] = None
flex_attn: ClassVar[Optional[Callable]] = None
attn_bias_type: ClassVar[Optional[str]] = None
compiled_create_block_mask: ClassVar[Optional[Callable]] = None

def __init__(self, attn_bias_type: str) -> None:
# need to amortize the cost of compilation.
flex_attn: ClassVar[Callable] = torch.compile(
flex_attention, mode="max-autotune-no-cudagraphs"
)
compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
used_attn_mask_types: ClassVar[set[str]] = set()
# Maps attention mask type to the created (id(batch), BlockMask).
# This allows us to keep track of the block masks created for each
# new batch and to rebuild them when a new batch arrives. It also
# allows users to create different block masks for different layers.
block_masks: ClassVar[dict[str, BatchBlockMaskType]] = {}

# Instance variables.
attn_mask_type: str

def __init__(self, attn_mask_type: str) -> None:
super().__init__()
if FlexAttn.attn_bias_type is not None:
assert (
FlexAttn.attn_bias_type == attn_bias_type
), "All FlexAttention must have the same configurations."
else:
if attn_bias_type not in ["causal", "block_causal"]:
raise ValueError(f"Unrecognized attn_bias_type {attn_bias_type}.")
FlexAttn.attn_bias_type = attn_bias_type
if attn_mask_type not in ["causal", "block_causal"]:
raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.")
self.attn_mask_type = attn_mask_type
FlexAttn.used_attn_mask_types.add(attn_mask_type)

def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
assert FlexAttn.block_mask is not None
assert FlexAttn.flex_attn is not None
return FlexAttn.flex_attn(q, k, v, block_mask=FlexAttn.block_mask)
block_mask = FlexAttn.block_masks[self.attn_mask_type][1]
return FlexAttn.flex_attn(q, k, v, block_mask=block_mask)

@classmethod
def _get_causal_mask_fn(cls) -> Callable:
@@ -66,34 +73,41 @@ def block_causal_mask(b, h, q_idx, kv_idx):

@classmethod
@torch.no_grad()
def init_attention_bias(
def init_attention_mask(
cls, batch: torch.Tensor, eos_id: Optional[int] = None
) -> None:
if cls.block_mask is not None and cls.attn_bias_type == "causal":
# We don't need to create another block mask for causal masking if existed.
return

match cls.attn_bias_type:
case "causal":
mask_fn = cls._get_causal_mask_fn()
case "block_causal":
mask_fn = cls._get_block_causal_mask_fn(batch, eos_id)
case _:
raise RuntimeError(f"Shouldn't reach here. {cls.attn_bias_type}")

seq_len = batch.shape[1]
if cls.compiled_create_block_mask is None:
cls.compiled_create_block_mask = torch.compile(create_block_mask)
cls.block_mask = cls.compiled_create_block_mask(
mask_fn, None, None, seq_len, seq_len
)
cls.flex_attn = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs")


class SDPA(torch.nn.Module):
def __init__(self, attn_bias_type: str) -> None:
for attn_mask_type in cls.used_attn_mask_types:
block_mask = cls.block_masks.get(attn_mask_type, None)
if block_mask is not None:
batch_id = block_mask[0]
if batch_id is None or batch_id == id(batch):
continue

match attn_mask_type:
case "causal":
batch_id = None
mask_fn = cls._get_causal_mask_fn()
case "block_causal":
batch_id = id(batch)
if eos_id is None:
raise RuntimeError(
"eos_id must be provided for block_causal mask."
)
mask_fn = cls._get_block_causal_mask_fn(batch, eos_id)
case _:
raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}")

seq_len = batch.shape[1]
block_mask = cls.compiled_create_block_mask(
mask_fn, None, None, seq_len, seq_len
)
cls.block_masks[attn_mask_type] = (batch_id, block_mask)


class ScaledDotProductAttention(torch.nn.Module):
def __init__(self, attn_mask_type: str) -> None:
super().__init__()
if attn_bias_type != "causal":
if attn_mask_type != "causal":
raise ValueError(
"TorchTitan with SDPA currently only supports causal mask."
)
@@ -103,32 +117,13 @@ def forward(
) -> torch.Tensor:
return F.scaled_dot_product_attention(q, k, v, is_causal=True)

@classmethod
@torch.no_grad()
def init_attention_bias(
cls,
batch: torch.Tensor,
eos_id: Optional[int] = None,
) -> None:
# For SDPA, we don't need to do anything.
return


_selected_attention = None


def build_attention(use_flex_attn: bool, attn_bias_type: str):
global _selected_attention
def build_attention(use_flex_attn: bool, attn_mask_type: str):
if use_flex_attn:
assert _selected_attention is None or _selected_attention == FlexAttn
_selected_attention = FlexAttn
return FlexAttn(attn_bias_type)
return FlexAttn(attn_mask_type)
else:
assert _selected_attention is None or _selected_attention == SDPA
_selected_attention = SDPA
return SDPA(attn_bias_type)
return SDPA(attn_mask_type)


def init_attention_bias(batch: torch.Tensor, eos_id: Optional[int] = None) -> None:
global _selected_attention
_selected_attention.init_attention_bias(batch, eos_id)
def init_attention_mask(batch: torch.Tensor, eos_id: Optional[int] = None) -> None:
FlexAttn.init_attention_mask(batch, eos_id)
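
To make the new per-mask-type caching concrete, below is a minimal, self-contained Python sketch of the same idea. It is not the torchtitan module itself: the names (_block_masks, init_masks, the mask functions) and the exact block_causal document convention are illustrative assumptions. A cached "causal" mask is reused indefinitely, while a "block_causal" mask is rebuilt whenever a new batch, identified by id(batch), arrives.

import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask

# mask type -> (id of the batch the mask was built for, BlockMask).
# "causal" is batch-independent (its batch id stays None, so it is built once);
# "block_causal" is rebuilt whenever a new batch shows up.
_block_masks: dict[str, tuple[int | None, BlockMask]] = {}


def _causal_mask_fn():
    def causal(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx
    return causal


def _block_causal_mask_fn(batch: torch.Tensor, eos_id: int):
    # Document ids increase at each EOS token (one common convention), so a
    # token may only attend to earlier positions inside its own document.
    docs = (batch == eos_id).int().cumsum(dim=1)

    def block_causal(b, h, q_idx, kv_idx):
        return (docs[b, q_idx] == docs[b, kv_idx]) & (q_idx >= kv_idx)
    return block_causal


def init_masks(used_types: set[str], batch: torch.Tensor, eos_id: int | None) -> None:
    bsz, seq_len = batch.shape
    for mask_type in used_types:
        cached = _block_masks.get(mask_type)
        if cached is not None and (cached[0] is None or cached[0] == id(batch)):
            continue  # cached mask is still valid for this batch
        if mask_type == "causal":
            batch_id, mask_fn = None, _causal_mask_fn()
        elif mask_type == "block_causal":
            if eos_id is None:
                raise RuntimeError("eos_id must be provided for block_causal mask.")
            batch_id, mask_fn = id(batch), _block_causal_mask_fn(batch, eos_id)
        else:
            raise ValueError(f"Unrecognized mask type {mask_type}.")
        _block_masks[mask_type] = (
            batch_id,
            create_block_mask(mask_fn, bsz, None, seq_len, seq_len, device=batch.device),
        )

Keying the cache by mask type, rather than keeping a single class-wide mask, is what lets layers with different attn_mask_type values coexist in one model.
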
10 changes: 5 additions & 5 deletions torchtitan/models/llama/model.py
@@ -16,7 +16,7 @@

from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.models.attention import build_attention, init_attention_bias
from torchtitan.models.attention import build_attention, init_attention_mask
from torchtitan.models.norms import build_norm
from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol

@@ -40,15 +40,15 @@ class TransformerModelArgs(BaseModelArgs):
norm_type: str = "rmsnorm"

use_flex_attn: bool = False
attn_bias_type: str = "causal"
attn_mask_type: str = "causal"
eos_id: int = 0
Contributor

Is this necessary? I think the model should not have to know about properties of the tokenizer like this

Contributor Author
@fegin fegin Mar 27, 2025

Why is this an issue? It is configurable and is given by the tokenizer, not defined by the model. Otherwise, how will the attention module separate the tokens?

Contributor

@ebsmothers
I feel it would be one way or another -- we either put eos_id here, or each iteration would take a mask_mod function as model input. It is not clear to me which is cleaner -- I feel the current one is not bad. WDYT?

Contributor Author

I don't have a strong opinion of which approach is better. Both are viable approaches. I just want to emphasize that the model has to know eos_id, whether it is saved as an instance variable or is passed as an argument through forward.

Contributor

@fegin maybe I am misunderstanding something here then. It seems to me that most of the logic for constructing the block mask should be handled by the dataloader. You are already iterating over a bunch of samples that you want to pack, right? Why not just use this opportunity to construct a list of sequence lengths in the pack? Then this can be used to construct a BlockMask for flex without the model needing to know anything about eos_id. @tianyu-l iiuc this is kinda similar to what you're proposing, but passing BlockMask instead of the mask_mod function

Contributor
@tianyu-l tianyu-l Mar 28, 2025

@fegin I also didn't get this:

> I just want to emphasize that the model has to know eos_id, whether it is saved as an instance variable or is passed as an argument through forward.

An alternative seems to be to pass the block_mask_fn (say from the data loader, or a util function) to model forward. Technically, in that case, the model doesn't know eos_id.

Contributor Author
@fegin fegin Mar 31, 2025

@ebsmothers, @tianyu-l

> maybe I am misunderstanding something here then. It seems to me that most of the logic for constructing the block mask should be handled by the dataloader.

This is debatable. If you build the block-causal mask inside the dataloader, you are polluting the dataloader. There are options to use either Flex or Flash attention, and there can be more than just a simple block-causal mask. Why does the dataloader need to know whether the model uses Flex or Flash, or what kind of masking the model uses? It comes down to which component you choose to expose to the other component's details.

We discussed internally whether to couple mask construction with the dataloader, and there was an opinion to decouple the dataloader from the attention implementation: researchers can experiment with different attention masks without changing the dataloader, since knowing EOS_ID is enough. I would prefer to keep that decision, even though it was not made specifically for TorchTitan; the discussion was about how to do CP + Flex for PTD.

Contributor

Just one more quick point: this is our generic document mask re-writer: https://github.com/pytorch-labs/attention-gym/blob/001b36d625aceae8a03f59241113e4797122db1d/attn_gym/masks/document_mask.py#L33. It takes in two things: another mask_mod and a tensor specifying the boundaries. I find this de-coupling pretty attractive, where you decouple the generation of the extra metadata from which mask_mod you want to perform; in this concrete case the choice is "causal". Someone has to own setting up your input tensor in the packed format, and the generation of the extra metadata likely should be colocated IMO.
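
As a rough illustration of the decoupling discussed above, here is a small sketch modeled on (but not copied from) the attention-gym document-mask re-writer linked in the comment; the function name and signature are assumptions. It wraps an arbitrary base mask_mod with precomputed document ids, so whichever component packs the batch owns the metadata while the choice of mask stays separate.

import torch


def with_document_mask(base_mask_mod, document_ids: torch.Tensor):
    # document_ids: [seq_len] tensor mapping each position to its document
    # index, produced by whichever component packs the sequences.
    def doc_mask_mod(b, h, q_idx, kv_idx):
        same_doc = document_ids[q_idx] == document_ids[kv_idx]
        return same_doc & base_mask_mod(b, h, q_idx, kv_idx)
    return doc_mask_mod


def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


# Example: three packed documents of lengths 3, 2, and 3 in an 8-token sequence.
doc_ids = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])
block_causal_mask_mod = with_document_mask(causal, doc_ids)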


def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
self.norm_type = job_config.model.norm_type
self.vocab_size = tokenizer.n_words
self.max_seq_len = job_config.training.seq_len
self.use_flex_attn = job_config.model.use_flex_attn
self.attn_bias_type = job_config.model.attn_bias_type
self.attn_mask_type = job_config.model.attn_mask_type

def get_num_flop_per_token(self, num_params: int, seq_len: int) -> int:
l, h, q, t = (
@@ -196,7 +196,7 @@ def __init__(self, model_args: TransformerModelArgs):
self.wo = nn.Linear(
model_args.n_heads * self.head_dim, model_args.dim, bias=False
)
self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_bias_type)
self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type)

def init_weights(self, init_std: float):
for linear in (self.wq, self.wk, self.wv):
@@ -471,7 +471,7 @@ def forward(self, tokens: torch.Tensor):
torch.Tensor: Output logits after applying the Transformer model.

"""
init_attention_bias(tokens, eos_id=self.eos_id)
init_attention_mask(tokens, eos_id=self.eos_id)

# passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
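
Finally, a hypothetical end-to-end usage sketch (not part of this PR; it assumes torchtitan is installed, a CUDA device is available, and relies only on the build_attention and init_attention_mask functions introduced in this diff): two attention modules requesting different FlexAttention mask types, with a single per-batch mask-initialization call shared between them.

import torch
from torchtitan.models.attention import build_attention, init_attention_mask

causal_attn = build_attention(use_flex_attn=True, attn_mask_type="causal")
block_attn = build_attention(use_flex_attn=True, attn_mask_type="block_causal")

# Toy batch of token ids; eos_id=0 marks document boundaries for block_causal.
tokens = torch.randint(1, 100, (2, 128), device="cuda")
init_attention_mask(tokens, eos_id=0)  # builds and caches one BlockMask per used mask type

q = k = v = torch.randn(2, 8, 128, 64, device="cuda")  # [batch, heads, seq_len, head_dim]
out_causal = causal_attn(q, k, v)
out_block = block_attn(q, k, v)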