incorrect transpose in Qwen3 varlen attention backend
While reviewing the Qwen3 implementation, I noticed that the attention output is always transposed at the end, regardless of which attention backend is used.
For the flex and sdpa backends, this final transpose appears correct. For the varlen backend, however, it seems to introduce an extra transpose that swaps the head dimension with the head-size dimension. The subsequent .view() then reshapes the tensor back to the expected shape, effectively hiding the issue.
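To illustrate how the reshape can mask a wrong transpose, here is a minimal standalone sketch (toy sizes, hypothetical stand-in tensor, not the actual model code), assuming the varlen kernel returns a token-major `(total_tokens, n_heads, head_dim)` tensor:

```python
import torch

# Hypothetical small sizes for illustration.
bs, seqlen, n_heads, head_dim = 1, 3, 2, 4
total_tokens = bs * seqlen

# Stand-in for a varlen-style output: already token-major,
# shaped (total_tokens, n_heads, head_dim).
out = torch.arange(total_tokens * n_heads * head_dim, dtype=torch.float32).view(
    total_tokens, n_heads, head_dim
)

# The extra transpose swaps n_heads and head_dim...
wrong = out.transpose(1, 2).contiguous().view(bs, seqlen, -1)
# ...whereas a plain flatten keeps each head's values contiguous.
right = out.view(bs, seqlen, -1)

# Same final shape, so the .view() hides the bug, but the values are scrambled.
assert wrong.shape == right.shape == (bs, seqlen, n_heads * head_dim)
assert not torch.equal(wrong, right)
```

Because the element count is unchanged, `.view(bs, seqlen, -1)` succeeds either way; only the per-token ordering of values differs, which is why this does not surface as a shape error.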
Relevant code:
- Qwen3 model output handling: torchtitan/torchtitan/models/qwen3/model/model.py, lines 271 to 298 at 9f211ec:

  ```python
  match self.attn_type:
      case "flex":
          assert isinstance(attention_masks, BlockMask), attention_masks
          output = self.inner_attention(
              xq, xk, xv, block_mask=attention_masks, scale=self.scaling
          )
      case "varlen":
          # TODO: pass self.scaling into varlen attention
          assert isinstance(attention_masks, VarlenMetadata), attention_masks
          output = self.inner_attention(
              xq,
              xk,
              xv,
              self.head_dim,
              attention_masks,
              scale=self.scaling,
          )
      case "sdpa":
          assert attention_masks is None
          output = self.inner_attention(xq, xk, xv, scale=self.scaling)
      case _:
          raise ValueError(f"Unknown attention type: {self.attn_type}")

  output = output.transpose(
      1, 2
  ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
  output = output.view(bs, seqlen, -1)
  ```

- Attention backend wrappers: torchtitan/torchtitan/models/attention.py, lines 70 to 88 at 9f211ec:

  ```python
  n_local_heads = xq.shape[1]
  # pyrefly: ignore [no-matching-overload]
  xq_packed = xq.transpose(1, 2).reshape(-1, n_local_heads, head_dim)
  # pyrefly: ignore [no-matching-overload]
  xk_packed = xk.transpose(1, 2).reshape(-1, n_local_heads, head_dim)
  # pyrefly: ignore [no-matching-overload]
  xv_packed = xv.transpose(1, 2).reshape(-1, n_local_heads, head_dim)

  return VarlenAttentionWrapper._compiled_varlen_attn(
      xq_packed,
      xk_packed,
      xv_packed,
      cu_seq_q,
      cu_seq_k,
      max_q,
      max_k,
      is_causal=True,
      scale=scale,
  )
  ```
The root cause seems to be that the varlen wrapper already transposes the head and sequence dimensions of Q/K/V (back to token-major layout) and packs them before calling into the kernel. As a result, the varlen output comes back token-major and should not be transposed again. In contrast, the flex and sdpa wrappers do not perform this pre-transposition, so their outputs do require the final transpose.
A straightforward fix would be to move the final output.transpose(...) inside the match cases for sdpa and flex, but exclude it from the varlen case.
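One way to sketch that fix in isolation (hypothetical helper name and toy tensors; the real change would live inside the model's match statement):

```python
import torch

def finalize_attention_output(
    output: torch.Tensor, attn_type: str, bs: int, seqlen: int
) -> torch.Tensor:
    """Hypothetical helper: only sdpa/flex outputs arrive head-major as
    (bs, n_heads, seqlen, head_dim) and need the transpose; a varlen
    output is assumed to already be token-major."""
    if attn_type in ("flex", "sdpa"):
        # (bs, n_heads, seqlen, head_dim) -> (bs, seqlen, n_heads, head_dim)
        output = output.transpose(1, 2).contiguous()
    elif attn_type != "varlen":
        raise ValueError(f"Unknown attention type: {attn_type}")
    return output.view(bs, seqlen, -1)

# The same logical output in both layouts should flatten identically.
bs, seqlen, n_heads, head_dim = 2, 4, 2, 3
token_major = torch.randn(bs, seqlen, n_heads, head_dim)
head_major = token_major.transpose(1, 2).contiguous()

a = finalize_attention_output(head_major, "sdpa", bs, seqlen)
b = finalize_attention_output(token_major, "varlen", bs, seqlen)
assert torch.equal(a, b)
```

With the transpose applied conditionally, both paths produce identical `(bs, seqlen, n_heads * head_dim)` results for the same underlying attention output.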
If you can confirm that my understanding is correct, I’m happy to submit a quick PR to address this.