[model, perf] feat: real THD packing in qwen3_vl_step #3838

Open: cuichenx wants to merge 1 commit into main from chcui/qwen3vl-step-thd-packing

Conversation

@cuichenx
Contributor

Summary

Make dataset.pack_sequences_in_batch=True actually skip pad-attention FLOPs on the
qwen3_vl_step path (HF datasets such as cord_v2). The current implementation only
reshapes the BSHD batch into a flat row and emits cu_seqlens = [0, T, 2T, …, BT].
Every segment therefore looks fully real to flash-attn / GDN, so attention still
computes the full padded tile per row.
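
To make the difference concrete, here is a small plain-Python sketch (not code from this PR) that contrasts uniform boundaries with boundaries built from real lengths. The function names, the batch shape, and the per-sample lengths are all made up for illustration.

```python
from itertools import accumulate

def uniform_cu_seqlens(batch_size, seq_len):
    # Boundaries of the form [0, T, 2T, ..., BT]: every segment spans a full padded row.
    return [i * seq_len for i in range(batch_size + 1)]

def real_cu_seqlens(real_lengths):
    # Boundaries built from the actual number of non-pad tokens per sample.
    return [0, *accumulate(real_lengths)]

B, T = 4, 2048                         # illustrative batch shape
real_lengths = [300, 1200, 64, 2048]   # hypothetical per-sample real lengths

uniform = uniform_cu_seqlens(B, T)
real = real_cu_seqlens(real_lengths)

# Tokens the attention kernel is told are real under each scheme.
tokens_uniform = uniform[-1]
tokens_real = real[-1]
print(uniform, real, tokens_uniform, tokens_real)
```

With these made-up lengths, the uniform boundaries declare 8192 tokens as real while the real boundaries declare only 3612, which is the gap the kernel can skip once it is told about it.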

This PR rewires the packing branch in forward_step to:

  1. Capture real per-sample lengths from the dataset's 2D (B, T) attention mask
    before pad_or_truncate_attn_to_len strips it (that utility only accepts 4D
    Megatron-style masks).
  2. Concatenate per-sample real content into a single (1, total_padded) THD row.
  3. Emit cu_seqlens_q/kv (real boundaries) distinct from cu_seqlens_q/kv_padded
    (kernel stride into the flat tensor) — this is what lets the attention kernel skip
    the pad tail inside each segment.
  4. Pass rope_cu_seqlens and moe_padding_mask so the model's per-sub-seq MRoPE
    resets at segment boundaries and the MoE router doesn't count align-pad as real.
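
Steps 1 to 3 can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the PR's helpers: the names `lengths_from_mask` and `pack_rows`, the `align` parameter, right-padding of the input rows, and the mask polarity (True marks align-pad) are all hypothetical. The MRoPE handling from step 4 is not covered here.

```python
from itertools import accumulate

def lengths_from_mask(mask_2d):
    # Step 1: real length per sample from a (B, T) mask (1 = real, 0 = pad).
    # Assumes right-padding, so the count of ones equals the real prefix length.
    return [sum(row) for row in mask_2d]

def pack_rows(rows, real_lengths, align=1, pad_id=0):
    # Steps 2 and 3: concatenate real content into one flat row and build
    # two sets of boundaries, real and padded.
    packed, pad_mask, padded_lengths = [], [], []
    for row, n in zip(rows, real_lengths):
        n_padded = -(-n // align) * align      # round n up to a multiple of align
        tail = n_padded - n
        packed += list(row[:n]) + [pad_id] * tail
        pad_mask += [False] * n + [True] * tail  # assumed polarity: True = align-pad
        padded_lengths.append(n_padded)
    cu_seqlens = [0, *accumulate(real_lengths)]           # real boundaries
    cu_seqlens_padded = [0, *accumulate(padded_lengths)]  # offsets into the flat row
    return packed, cu_seqlens, cu_seqlens_padded, pad_mask

rows = [[11, 12, 13, 0, 0, 0], [21, 22, 23, 24, 25, 0]]
mask = [[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 0]]

lengths = lengths_from_mask(mask)
packed, cu, cu_padded, pad_mask = pack_rows(rows, lengths, align=4)
print(lengths, packed, cu, cu_padded)
```

In this toy batch the two samples have real lengths 3 and 5, each segment is rounded up to a multiple of 4, and the real and padded boundaries differ exactly by the align-pad inside each segment.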

Two new helpers, no changes to public APIs or other models.

Empirical results

Qwen3.5-VL 2B SFT on cord_v2, 1 node × 8 H100, seq_length=2048, mcore main @ d167123d
(includes NVIDIA/Megatron-LM#2645 GDN packed-seq support).

| config | PACK=0 (BSHD) | PACK=1 (THD) |
|---|---|---|
| MBS=2 GBS=16, force-pad to seq_length | 363 ms / iter | 342 ms / iter |
| MBS=4 GBS=32, force-pad to seq_length | OOM | 505 ms / iter |

Loss curves match within bf16 noise. The bigger win is that THD enables micro-batch
sizes that BSHD cannot fit under PP>1 / EP>1 force-padding — at MBS=4 we get +44%
samples/sec over the best BSHD-fit baseline.
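
The throughput figures follow from the reported iteration times. The short check below assumes that one iteration processes exactly GBS samples; the helper name `samples_per_sec` is illustrative.

```python
def samples_per_sec(global_batch_size, ms_per_iter):
    # Assumes one iteration consumes exactly one global batch.
    return global_batch_size / (ms_per_iter / 1000.0)

bshd_mbs2 = samples_per_sec(16, 363)   # best configuration that fits in BSHD
thd_mbs2 = samples_per_sec(16, 342)    # same batch sizes, THD packing
thd_mbs4 = samples_per_sec(32, 505)    # only fits with THD packing

gain_same_mbs = thd_mbs2 / bshd_mbs2 - 1.0   # speedup at fixed MBS
gain_best = thd_mbs4 / bshd_mbs2 - 1.0       # THD at MBS=4 vs best BSHD

print(f"{bshd_mbs2:.1f} {thd_mbs2:.1f} {thd_mbs4:.1f} samples/sec")
print(f"{gain_same_mbs:+.1%} at fixed MBS, {gain_best:+.1%} at MBS=4")
```

This gives roughly 44.1, 46.8, and 63.4 samples/sec, that is about +6% at fixed MBS and about +44% for MBS=4 over the BSHD baseline.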

Dependencies — DO NOT MERGE BEFORE

Usage

To exercise the packing path on HF VLM datasets, the user must also set
dataset.skip_getting_attention_mask_from_dataset=False so the real attention mask
reaches forward_step. The recipe-side default for that flag will be addressed
separately. A fallback path that infers real lengths from the trailing zero-pad is
provided for completeness, but it is only correct in the (uncommon) case where the
collator's tokenizer.pad_token_id is 0.
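
The limitation of the fallback is easy to see in a toy sketch. The function below is a hypothetical stand-in for the trailing-zero detection, not the PR's code, and the pad id 99 is an arbitrary example value.

```python
def infer_length_from_trailing_zeros(row):
    # Walk back from the end of the row while the token id is 0.
    n = len(row)
    while n > 0 and row[n - 1] == 0:
        n -= 1
    return n

zero_padded = [5, 9, 7, 0, 0]        # pad_token_id == 0: padding is detectable
nonzero_padded = [5, 9, 7, 99, 99]   # hypothetical pad_token_id == 99

len_zero = infer_length_from_trailing_zeros(zero_padded)
len_nonzero = infer_length_from_trailing_zeros(nonzero_padded)
print(len_zero, len_nonzero)
```

The zero-padded row yields the correct length 3, while the row padded with a non-zero id is treated as 5 real tokens, so no padding is skipped. Passing the real attention mask avoids this.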

Test plan

🤖 Generated with Claude Code

Replace the in-forward layout-shuffle that emitted uniform-length cu_seqlens
with a true variable-length THD packer.  Real per-sample content is now
concatenated into a single (1, total_padded) THD row and exposed to flash-attn
via cu_seqlens_unpadded (real lengths) distinct from cu_seqlens_padded (kernel
stride), so the attention kernel can skip the pad tail inside each segment.

Adds three pieces of plumbing alongside the existing pack_or_pad_batch_sequences
fallback:

  * _per_sample_real_lengths(): reduces a 2D or 4D attention_mask to (B,) int32.
  * _pack_bshd_to_thd(): builds the packed input + cu_seqlens_(un)padded +
    moe_padding_mask from real lengths.
  * forward_step now captures the dataset's 2D attention_mask before
    pad_or_truncate_attn_to_len strips it (the existing util only accepts 4D
    masks), then drives packing from those lengths.  rope_cu_seqlens is passed
    so the model's per-sub-seq MRoPE path restarts position ids at segment
    boundaries; moe_padding_mask keeps align-pad out of the router stats.

For HF-dataset users this is gated by both
  dataset.pack_sequences_in_batch=True
  dataset.skip_getting_attention_mask_from_dataset=False
so the real attention_mask reaches forward_step.  Without the mask the code
falls back to "last non-zero token" detection from the BSHD pad, which is only
correct when the padded positions hold token id 0, i.e. when the collator's
tokenizer.pad_token_id is 0.

Empirically on Qwen3.5-VL 2B SFT (cord_v2, 1 node × 8 H100, seq_length=2048,
mcore main @ d167123d incl. NVIDIA/Megatron-LM#2645 GDN packed-seq):

  | config                                    | PACK=0 (BSHD) | PACK=1 (THD) |
  |-------------------------------------------|---------------|--------------|
  | MBS=2 GBS=16, force-pad to seq_length     | 363 ms/iter   | 342 ms/iter  |
  | MBS=4 GBS=32, force-pad to seq_length     | OOM           | 505 ms/iter  |

Loss curves match within bf16 noise.  Per-iter speedup at fixed MBS is modest
on this hybrid GDN model (~6%) because attention is only ~12-15% of step time;
the larger win is that THD unlocks micro-batch sizes that BSHD cannot fit under
PP>1 / EP>1 force-padding, yielding +44% samples/sec at MBS=4 vs the best
BSHD-fit configuration.

Requires two changes to land first: the upstream mcore changes from #3323
(the model.forward rope_cu_seqlens / moe_padding_mask kwargs) and
NVIDIA/Megatron-LM#2645 (GDN cu_seqlens support).

Signed-off-by: Chen Cui <chcui@nvidia.com>
@cuichenx
Contributor Author

/claude review

@yaoyu-33 added labels on May 15, 2026: area:model (Model implementations and HF bridge logic), blocked (Work cannot move forward until an external dependency is cleared), feature (New capabilities, enhancements, or enablement work), needs-more-tests (Requires additional L0 and L1 test coverage before merge)
@liangxuZhang

Great work, may I ask why the original solution used the BSHD layout in Qwen3 VL training. Is there any difference in the Qwen3 VL model?

@cuichenx
Contributor Author

> Great work, may I ask why the original solution used the BSHD layout in Qwen3 VL training. Is there any difference in the Qwen3 VL model?

qwen3_vl_step was created for VERL to use. VERL required the input to the forward pass to be in BSHD format, so the packing was done inside the forward pass.

@zhongbozhu
Contributor

Let's also figure out if HybridEP can be used after removing the padding.

raise ValueError(f"Unsupported attention_mask rank: {attention_mask.dim()}")


def _pack_bshd_to_thd(

Maybe we should make this optimization optional, since it affects the feasibility of HybridEP for MoE models?

@zhongbozhu
Contributor

Will CP>1 be supported with this PR? It looks like it will have some impact on CP.

4 participants