Description
Avoid unnecessary KV head repetition for GQA when using varlen (flash-attn) backend
This issue is related to #2223.
To support GQA with the sdpa and flex attention backends, the attention module correctly repeats key and value heads to match the number of query heads, since neither SDPA nor FLEX natively supports GQA.
However, the varlen backend is backed by [flash-attn](https://github.com/Dao-AILab/flash-attention), which does support GQA natively. In this case, repeating keys and values is unnecessary and can significantly increase memory usage.
Relevant code:
- KV repetition in Qwen3 attention:
`torchtitan/torchtitan/models/qwen3/model/model.py`, lines 260 to 270 at `9f211ec`:

```python
# Apply rotary embedding
xq, xk = apply_rotary_emb(xq, xk, rope_cache, positions)

# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
values = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
```
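To make the memory cost concrete, here is a minimal NumPy sketch of the head-repetition trick (analogous to `repeat_kv` above; the function and shapes here are illustrative, not the actual torchtitan code):

```python
import numpy as np

def repeat_kv(x: np.ndarray, n_rep: int) -> np.ndarray:
    """Repeat each KV head n_rep times along the head axis:
    (bs, seqlen, n_kv_heads, head_dim) -> (bs, seqlen, n_kv_heads * n_rep, head_dim).
    """
    if n_rep == 1:
        return x
    bs, slen, n_kv_heads, head_dim = x.shape
    # Broadcast then reshape, which materializes n_rep copies of every KV head.
    return np.broadcast_to(
        x[:, :, :, None, :], (bs, slen, n_kv_heads, n_rep, head_dim)
    ).reshape(bs, slen, n_kv_heads * n_rep, head_dim)

# Example: 32 query heads over 8 KV heads => n_rep = 4, so the repeated
# K/V tensors occupy 4x the memory of the originals, which is wasted work
# for a backend like flash-attn that handles GQA natively.
k = np.zeros((2, 128, 8, 64), dtype=np.float32)
k_rep = repeat_kv(k, 4)
```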
Proposed fix
The fix should be straightforward:
- Move the `.repeat()` of keys and values into the `match` cases for the `sdpa` and `flex` backends, and explicitly exclude the `varlen` case.
- Slightly modify the `varlen` wrapper to distinguish between `n_local_heads` for queries and `n_kv_local_heads` for keys and values.
- Optionally, add an assertion enforcing that `n_local_heads` is a multiple of `n_kv_local_heads`, as required by flash-attn.
Relevant wrapper code:
`torchtitan/torchtitan/models/attention.py`, lines 70 to 76 at `9f211ec`:

```python
n_local_heads = xq.shape[1]
# pyrefly: ignore [no-matching-overload]
xq_packed = xq.transpose(1, 2).reshape(-1, n_local_heads, head_dim)
# pyrefly: ignore [no-matching-overload]
xk_packed = xk.transpose(1, 2).reshape(-1, n_local_heads, head_dim)
# pyrefly: ignore [no-matching-overload]
xv_packed = xv.transpose(1, 2).reshape(-1, n_local_heads, head_dim)
```
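The proposed packing change can be sketched as follows (a NumPy sketch with a hypothetical `pack_for_varlen` helper; names mirror the wrapper, but this is illustrative, not the actual torchtitan API):

```python
import numpy as np

def pack_for_varlen(xq, xk, xv):
    """Sketch of the proposed varlen packing: keep K/V at their native
    head count instead of repeating them to match the query heads.
    Inputs are (bs, n_heads, seqlen, head_dim), as in the current wrapper.
    """
    n_local_heads = xq.shape[1]     # query heads
    n_kv_local_heads = xk.shape[1]  # KV heads; may be smaller under GQA
    head_dim = xq.shape[-1]
    # flash-attn's native GQA path requires the query head count to be
    # a multiple of the KV head count.
    assert n_local_heads % n_kv_local_heads == 0
    # transpose(0, 2, 1, 3) is NumPy's equivalent of torch's .transpose(1, 2)
    xq_packed = xq.transpose(0, 2, 1, 3).reshape(-1, n_local_heads, head_dim)
    xk_packed = xk.transpose(0, 2, 1, 3).reshape(-1, n_kv_local_heads, head_dim)
    xv_packed = xv.transpose(0, 2, 1, 3).reshape(-1, n_kv_local_heads, head_dim)
    return xq_packed, xk_packed, xv_packed
```

With this shape convention, the K/V tensors stay `n_rep` times smaller than in the repeated layout, and the backend resolves the grouping itself.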
Optional follow-up
If a slightly breaking change is acceptable, we could also remove the transposes currently performed inside the varlen wrapper, which would eliminate the need for an additional transpose before calling into the varlen backend.
If this proposal sounds reasonable, I can submit a PR that addresses this issue and also resolves #2223.