
Efficiency Issue with Qwen3 attention module #2224

@francesco-bertolotti


Avoid unnecessary KV head repetition for GQA when using varlen (flash-attn) backend

This issue is related to #2223.

To support GQA with the sdpa and flex attention backends, the attention module correctly repeats key and value heads to match the number of query heads, since neither SDPA nor FLEX natively supports GQA.

However, the varlen backend is backed by [flash-attn](https://github.com/Dao-AILab/flash-attention), which does support GQA natively. In this case, repeating keys and values is unnecessary and can significantly increase memory usage.
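To make the cost concrete, here is a back-of-the-envelope sketch of the overhead (the head counts and sizes below are illustrative, not taken from a specific Qwen3 checkpoint): repeating KV heads inflates the memory held by keys and values by a factor of n_heads / n_kv_heads.

```python
# Illustrative sketch: memory overhead of repeating KV heads for GQA.
# Head counts and sizes are hypothetical, not from a specific checkpoint.

def kv_bytes(bs, seqlen, n_kv_heads, head_dim, bytes_per_elem=2):
    # Bytes held by keys plus values (the leading 2) for one layer.
    return 2 * bs * seqlen * n_kv_heads * head_dim * bytes_per_elem

n_heads, n_kv_heads, head_dim = 32, 8, 128
n_rep = n_heads // n_kv_heads  # repetition factor used by repeat_kv

native = kv_bytes(bs=1, seqlen=8192, n_kv_heads=n_kv_heads, head_dim=head_dim)
repeated = kv_bytes(bs=1, seqlen=8192, n_kv_heads=n_kv_heads * n_rep, head_dim=head_dim)

print(repeated // native)  # repeating inflates K/V memory by n_rep = 4x
```

With these (made-up) numbers, skipping the repetition keeps K/V four times smaller before they are handed to flash-attn.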

Relevant code:

  • KV repetition in Qwen3 attention:
    # Apply rotary embedding
    xq, xk = apply_rotary_emb(xq, xk, rope_cache, positions)
    # repeat k/v heads if n_kv_heads < n_heads
    keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
    values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
    xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
    xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
    xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)

Proposed fix

The fix should be straightforward:

  1. Move the `repeat_kv` calls for keys and values into the match cases for the sdpa and flex backends, explicitly excluding the varlen case.

  2. Slightly modify the varlen wrapper to distinguish between:

    • n_local_heads for queries, and
    • n_kv_local_heads for keys and values.

    Optionally, add an assertion enforcing that n_local_heads is a multiple of n_kv_local_heads, as required by flash-attn.

Relevant wrapper code:

    n_local_heads = xq.shape[1]
    # pyrefly: ignore [no-matching-overload]
    xq_packed = xq.transpose(1, 2).reshape(-1, n_local_heads, head_dim)
    # pyrefly: ignore [no-matching-overload]
    xk_packed = xk.transpose(1, 2).reshape(-1, n_local_heads, head_dim)
    # pyrefly: ignore [no-matching-overload]
    xv_packed = xv.transpose(1, 2).reshape(-1, n_local_heads, head_dim)
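A sketch of step 2, again operating on shape tuples of the form (bs, n_heads, seqlen, head_dim) rather than real tensors (the variable names mirror the wrapper above; the exact function signature is an assumption):

```python
# Hypothetical sketch of the adjusted varlen wrapper packing step, using
# separate head counts for queries (n_local_heads) and K/V (n_kv_local_heads).

def packed_shapes(xq_shape, xk_shape, xv_shape):
    bs, n_local_heads, seqlen, head_dim = xq_shape
    n_kv_local_heads = xk_shape[1]
    # flash-attn requires the query head count to be a multiple of the KV head count.
    assert n_local_heads % n_kv_local_heads == 0
    xq_packed = (bs * seqlen, n_local_heads, head_dim)
    xk_packed = (bs * seqlen, n_kv_local_heads, head_dim)
    xv_packed = (bs * seqlen, n_kv_local_heads, head_dim)
    return xq_packed, xk_packed, xv_packed

q, k, v = packed_shapes((2, 32, 1024, 128), (2, 8, 1024, 128), (2, 8, 1024, 128))
print(q, k, v)  # (2048, 32, 128) (2048, 8, 128) (2048, 8, 128)
```

flash-attn's varlen entry point accepts packed queries with n_heads heads alongside packed keys/values with a smaller n_kv_heads, so no repetition is needed before the call.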

Optional follow-up

If a slightly breaking change is acceptable, we could also remove the transposes currently performed inside the varlen wrapper, which would eliminate the need for an additional transpose before calling into the varlen backend.

If this proposal sounds reasonable, I can submit a PR that addresses this issue and also resolves #2223.
