
Incorrect transpose in Qwen3 varlen attention backend #2223

@francesco-bertolotti

Description


While reviewing the Qwen3 implementation, I noticed that the attention output is always transposed at the end, regardless of which attention backend is used.

For the flex and sdpa backends, this final transpose appears correct. However, for the varlen backend, it seems to introduce an extra transpose that swaps the head dimension with the head size dimension. This is then masked by a subsequent `.view()` that reshapes the tensor back to the expected dimensions, effectively hiding the issue.

Relevant code:

• Qwen3 model output handling:

```python
match self.attn_type:
    case "flex":
        assert isinstance(attention_masks, BlockMask), attention_masks
        output = self.inner_attention(
            xq, xk, xv, block_mask=attention_masks, scale=self.scaling
        )
    case "varlen":
        # TODO: pass self.scaling into varlen attention
        assert isinstance(attention_masks, VarlenMetadata), attention_masks
        output = self.inner_attention(
            xq,
            xk,
            xv,
            self.head_dim,
            attention_masks,
            scale=self.scaling,
        )
    case "sdpa":
        assert attention_masks is None
        output = self.inner_attention(xq, xk, xv, scale=self.scaling)
    case _:
        raise ValueError(f"Unknown attention type: {self.attn_type}")
output = output.transpose(
    1, 2
).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
output = output.view(bs, seqlen, -1)
```
• Attention backend wrappers:

```python
n_local_heads = xq.shape[1]
# pyrefly: ignore [no-matching-overload]
xq_packed = xq.transpose(1, 2).reshape(-1, n_local_heads, head_dim)
# pyrefly: ignore [no-matching-overload]
xk_packed = xk.transpose(1, 2).reshape(-1, n_local_heads, head_dim)
# pyrefly: ignore [no-matching-overload]
xv_packed = xv.transpose(1, 2).reshape(-1, n_local_heads, head_dim)
return VarlenAttentionWrapper._compiled_varlen_attn(
    xq_packed,
    xk_packed,
    xv_packed,
    cu_seq_q,
    cu_seq_k,
    max_q,
    max_k,
    is_causal=True,
    scale=scale,
)
```

The root cause seems to be that the varlen wrapper already transposes the head and sequence dimensions of Q/K/V before packing them and calling into the backend, so the backend's output comes back already in the expected token-major layout. As a result, the output from varlen should not be transposed again. In contrast, the flex and sdpa wrappers do not perform this pre-transposition, so their outputs do require the final transpose.
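The shape accounting can be checked in isolation. The snippet below is a standalone sketch (the sizes are made up, and `out` stands in for a varlen-style output of shape `(total_tokens, n_heads, head_dim)`): applying the extra `transpose(1, 2)` and then `.view()` produces a tensor with the right shape but permuted values, which is exactly how the bug stays hidden.

```python
import torch

bs, seqlen, n_heads, head_dim = 2, 4, 3, 5
# Hypothetical varlen-style output: tokens already packed as
# (total_tokens, n_heads, head_dim); no batch/seq transpose is pending.
out = torch.randn(bs * seqlen, n_heads, head_dim)

# What the varlen path should do: flatten heads directly.
correct = out.reshape(bs, seqlen, n_heads * head_dim)

# What the current code does: transpose(1, 2) on a 3-D tensor swaps
# n_heads and head_dim, and the later .view() restores the expected
# shape while interleaving features head_dim-major instead of head-major.
buggy = out.transpose(1, 2).contiguous().view(bs, seqlen, -1)

assert buggy.shape == correct.shape    # shapes match...
assert not torch.equal(buggy, correct) # ...but the values are permuted
```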

A straightforward fix would be to move the final `output.transpose(1, 2)` into the match arms for the sdpa and flex cases, and omit it for the varlen case.
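As a rough illustration of that fix (a hypothetical standalone helper, not the actual model code), gating the transpose on the backend makes the two layouts produce identical final outputs:

```python
import torch

def attention_output(raw, attn_type, bs, seqlen):
    """Sketch of the proposed fix: only the backends that return
    (bs, n_heads, seqlen, head_dim) get the final transpose; the
    varlen output is assumed already packed as (total, n_heads, head_dim)."""
    if attn_type in ("flex", "sdpa"):
        raw = raw.transpose(1, 2).contiguous()  # -> (bs, seqlen, n_heads, head_dim)
    return raw.view(bs, seqlen, -1)

bs, seqlen, n_heads, head_dim = 2, 4, 3, 5
ref = torch.randn(bs, seqlen, n_heads, head_dim)

sdpa_out = ref.permute(0, 2, 1, 3)                    # (bs, n_heads, seqlen, head_dim)
varlen_out = ref.reshape(bs * seqlen, n_heads, head_dim)

# Both paths now recover the same (bs, seqlen, n_heads * head_dim) tensor.
assert torch.equal(
    attention_output(sdpa_out, "sdpa", bs, seqlen),
    attention_output(varlen_out, "varlen", bs, seqlen),
)
```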


If you can confirm that my understanding is correct, I’m happy to submit a quick PR to address this.
