[model, perf] feat: real THD packing in qwen3_vl_step #3838
Open
cuichenx wants to merge 1 commit into
Replace the in-forward layout-shuffle that emitted uniform-length cu_seqlens
with a true variable-length THD packer. Real per-sample content is now
concatenated into a single (1, total_padded) THD row and exposed to flash-attn
via cu_seqlens_unpadded (real lengths) distinct from cu_seqlens_padded (kernel
stride), so the attention kernel can skip the pad tail inside each segment.
Adds three pieces of plumbing alongside the existing pack_or_pad_batch_sequences
fallback:
* _per_sample_real_lengths(): reduces a 2D or 4D attention_mask to (B,) int32.
* _pack_bshd_to_thd(): builds the packed input + cu_seqlens_(un)padded +
moe_padding_mask from real lengths.
* forward_step now captures the dataset's 2D attention_mask before
pad_or_truncate_attn_to_len strips it (the existing util only accepts 4D
masks), then drives packing from those lengths. rope_cu_seqlens is passed
so the model's per-sub-seq MRoPE path restarts position ids at segment
boundaries; moe_padding_mask keeps align-pad out of the router stats.
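A minimal sketch of the packing idea described above, using plain Python lists in place of tensors so it runs without dependencies. The function names, the `align` parameter, the right-padding assumption, and the mask polarity (`True` marks align-pad) are all illustrative assumptions, not the PR's actual helpers; only the 2D mask case is shown.

```python
from typing import List, Tuple


def per_sample_real_lengths(attention_mask: List[List[int]]) -> List[int]:
    """Reduce a (B, T) 0/1 attention mask to B real-token counts."""
    return [sum(1 for m in row if m) for row in attention_mask]


def pack_rows(
    tokens: List[List[int]],
    real_lengths: List[int],
    align: int = 1,
    pad_id: int = 0,
) -> Tuple[List[int], List[int], List[int], List[bool]]:
    """Concatenate right-padded rows into one flat row (hypothetical helper).

    Returns (packed, cu_unpadded, cu_padded, padding_mask):
      cu_unpadded  cumulative real lengths
      cu_padded    cumulative aligned lengths, i.e. offsets into `packed`
      padding_mask True at positions that are alignment pad (assumed polarity)
    """
    packed: List[int] = []
    padding_mask: List[bool] = []
    cu_unpadded = [0]
    cu_padded = [0]
    for row, n in zip(tokens, real_lengths):
        padded_n = ((n + align - 1) // align) * align  # round up to `align`
        packed.extend(row[:n])                         # real content only
        packed.extend([pad_id] * (padded_n - n))       # alignment pad tail
        padding_mask.extend([False] * n + [True] * (padded_n - n))
        cu_unpadded.append(cu_unpadded[-1] + n)
        cu_padded.append(cu_padded[-1] + padded_n)
    return packed, cu_unpadded, cu_padded, padding_mask


mask = [[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 0]]
tokens = [[11, 12, 13, 0, 0, 0], [21, 22, 23, 24, 25, 0]]
lengths = per_sample_real_lengths(mask)
packed, cu_unpadded, cu_padded, pad_mask = pack_rows(tokens, lengths, align=4)
print(lengths, cu_unpadded, cu_padded)
```

With `align=4`, the two samples of real length 3 and 5 occupy 4 and 8 slots in the flat row, so the real boundaries and the storage offsets differ, which is the distinction the kernel needs in order to skip the pad tail.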
For HF-dataset users this is gated by both
dataset.pack_sequences_in_batch=True
dataset.skip_getting_attention_mask_from_dataset=False
so the real attention_mask reaches forward_step. Without the mask the code
falls back to "last non-zero token" detection from the BSHD pad, which works
when the collator does not insert a non-zero tokenizer.pad_token_id mid-row.
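The fallback can be pictured roughly as follows. This is a sketch under the assumption that rows are right-padded with zeros; `infer_lengths_from_trailing_zeros` is a hypothetical name, and the pad id `99` in the second example is made up purely to show the failure mode.

```python
from typing import List


def infer_lengths_from_trailing_zeros(input_ids: List[List[int]]) -> List[int]:
    """Treat the index after the last non-zero token as the real length."""
    lengths = []
    for row in input_ids:
        n = len(row)
        while n > 0 and row[n - 1] == 0:  # strip trailing zero pad only
            n -= 1
        lengths.append(n)
    return lengths


# Zero-padded rows: trailing pad is detected, interior zeros are kept.
zero_padded = infer_lengths_from_trailing_zeros([[5, 6, 7, 0, 0], [8, 0, 9, 0, 0]])

# Rows padded with a non-zero id (99 here, hypothetical): pad looks real.
nonzero_padded = infer_lengths_from_trailing_zeros([[5, 6, 99, 99]])
print(zero_padded, nonzero_padded)
```

The second call shows why the real attention mask is preferred: a non-zero pad id is indistinguishable from content, so the inferred length covers the whole row.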
Empirically on Qwen3.5-VL 2B SFT (cord_v2, 1 node × 8 H100, seq_length=2048,
mcore main @ d167123d incl. NVIDIA/Megatron-LM#2645 GDN packed-seq):
| config | PACK=0 (BSHD) | PACK=1 (THD) |
|-------------------------------------------|---------------|--------------|
| MBS=2 GBS=16, force-pad to seq_length | 363 ms/iter | 342 ms/iter |
| MBS=4 GBS=32, force-pad to seq_length | OOM | 505 ms/iter |
Loss curves match within bf16 noise. Per-iter speedup at fixed MBS is modest
on this hybrid GDN model (~6%) because attention is only ~12-15% of step time;
the larger win is that THD unlocks micro-batch sizes that BSHD cannot fit under
PP>1 / EP>1 force-padding, yielding +44% samples/sec at MBS=4 vs the best
BSHD-fit configuration.
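The two headline percentages can be re-derived from the table, assuming each iteration processes exactly GBS samples:

```python
def samples_per_sec(gbs: int, ms_per_iter: float) -> float:
    """Throughput, assuming one iteration consumes `gbs` samples."""
    return gbs / (ms_per_iter / 1000.0)


bshd_mbs2 = samples_per_sec(16, 363)  # best configuration that fits in BSHD
thd_mbs2 = samples_per_sec(16, 342)
thd_mbs4 = samples_per_sec(32, 505)   # BSHD runs out of memory at this size

fixed_mbs_speedup = thd_mbs2 / bshd_mbs2 - 1.0   # same MBS, THD vs BSHD
best_vs_best_gain = thd_mbs4 / bshd_mbs2 - 1.0   # THD MBS=4 vs BSHD MBS=2

print(f"fixed-MBS speedup: {fixed_mbs_speedup:.1%}")
print(f"THD MBS=4 vs best BSHD: {best_vs_best_gain:.1%}")
```

This gives roughly 44.1 samples/sec for BSHD at MBS=2 and 63.4 samples/sec for THD at MBS=4, consistent with the ~6% and +44% figures quoted above.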
Requires upstream mcore changes from #3323 for the
model.forward rope_cu_seqlens / moe_padding_mask kwargs and the
NVIDIA/Megatron-LM#2645 GDN-cu_seqlens support to land first.
Signed-off-by: Chen Cui <chcui@nvidia.com>
Contributor
Author
/claude review
Great work. May I ask why the original solution used the BSHD layout in Qwen3 VL training? Is there any difference in the Qwen3 VL model?
Contributor
Author
qwen3_vl_step was created for VERL to use. VERL required the input to the forward pass to be in BSHD format, so the packing was done inside the forward pass.
Contributor
Let's also figure out if HybridEP can be used after removing the padding.
zhongbozhu
reviewed
May 16, 2026
raise ValueError(f"Unsupported attention_mask rank: {attention_mask.dim()}")

def _pack_bshd_to_thd(
Contributor
Maybe we should make this optimization optional, since it affects the feasibility of HybridEP for MoE models?
Contributor
Will CP>1 be supported with this PR? It looks like it will have some impact on CP.
Summary
Make `dataset.pack_sequences_in_batch=True` actually skip pad-attention FLOPs on the `qwen3_vl_step` path (HF datasets such as cord_v2). The current implementation only relayouts BSHD into a flat row and emits `cu_seqlens = [0, T, 2T, …, BT]`: every segment looks fully real to flash-attn / GDN, so attention still computes the full padded tile per row.
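To make the cost of uniform boundaries concrete, the sketch below counts attention score entries per segment, assuming causal attention (an assumption; the kernel's exact tiling is not modelled). The helper name and the toy sizes are illustrative.

```python
from typing import List


def causal_score_entries(cu_seqlens: List[int]) -> int:
    """Count (query, key) pairs a causal kernel evaluates, segment by segment."""
    total = 0
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        n = end - start
        total += n * (n + 1) // 2  # lower triangle including the diagonal
    return total


T = 8                    # padded row length (toy value)
real_lengths = [3, 5]    # real tokens per sample (toy values)

uniform_cu = [i * T for i in range(len(real_lengths) + 1)]  # [0, T, 2T]
real_cu = [0]
for n in real_lengths:
    real_cu.append(real_cu[-1] + n)                          # [0, 3, 8]

uniform_work = causal_score_entries(uniform_cu)
real_work = causal_score_entries(real_cu)
print(uniform_cu, uniform_work, real_cu, real_work)
```

In this toy case the uniform boundaries cost 72 score entries against 21 for the real boundaries; the gap grows with the fraction of each row that is padding.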
This PR rewires the packing branch in `forward_step` to:
* Capture the dataset's `(B, T)` attention mask before `pad_or_truncate_attn_to_len` strips it (that utility only accepts 4D Megatron-style masks).
* Concatenate the real per-sample content into a single `(1, total_padded)` THD row.
* Emit `cu_seqlens_q/kv` (real boundaries) distinct from `cu_seqlens_q/kv_padded` (kernel stride into the flat tensor); this is what lets the attention kernel skip the pad tail inside each segment.
* Pass `rope_cu_seqlens` and `moe_padding_mask` so the model's per-sub-seq MRoPE resets at segment boundaries and the MoE router doesn't count align-pad as real.
Two new helpers, no changes to public APIs or other models.
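The effect of the two extra inputs can be illustrated with a simplified sketch. It uses 1D text-only positions rather than the three-axis MRoPE layout, and a plain per-expert token count as a stand-in for whatever statistics the router actually tracks; both helper names and the mask polarity (`True` marks align-pad) are assumptions.

```python
from typing import List


def restart_position_ids(cu_seqlens: List[int]) -> List[int]:
    """Position ids that restart at 0 at every segment boundary."""
    pos: List[int] = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        pos.extend(range(end - start))
    return pos


def expert_token_counts(
    expert_ids: List[int], padding_mask: List[bool], num_experts: int
) -> List[int]:
    """Tokens routed to each expert, skipping positions flagged as pad."""
    counts = [0] * num_experts
    for expert, is_pad in zip(expert_ids, padding_mask):
        if not is_pad:
            counts[expert] += 1
    return counts


cu_padded = [0, 4, 12]
pad_mask = [False, False, False, True,
            False, False, False, False, False, True, True, True]
expert_ids = [0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0]  # toy top-1 routing

positions = restart_position_ids(cu_padded)
masked_counts = expert_token_counts(expert_ids, pad_mask, num_experts=2)
naive_counts = expert_token_counts(expert_ids, [False] * 12, num_experts=2)
print(positions, masked_counts, naive_counts)
```

Here the masked counts are balanced at 4 and 4, while counting the pad positions would report 7 and 5, which is the kind of skew the padding mask is meant to keep out of the router.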
Empirical results
Qwen3.5-VL 2B SFT on cord_v2, 1 node × 8 H100, seq_length=2048, mcore main @ d167123d
(includes NVIDIA/Megatron-LM#2645 GDN packed-seq support).
Loss curves match within bf16 noise. The bigger win is that THD enables micro-batch
sizes that BSHD cannot fit under PP>1 / EP>1 force-padding — at MBS=4 we get +44%
samples/sec over the best BSHD-fit baseline.
Dependencies: DO NOT MERGE BEFORE
* #3323, which adds the `rope_cu_seqlens` / `moe_padding_mask` kwargs to `Qwen3VLModel.forward` and the matching plumbing through `text_model.py` / `transformer_block.py`. Without it, those two kwargs would be swallowed by `**kwargs`, and MRoPE and the MoE router would silently lose per-segment semantics.
* NVIDIA/Megatron-LM#2645, which adds packed-sequence support to `gated_delta_net`. Qwen3.5-VL is a hybrid attention/GDN architecture; before #2645, GDN explicitly raised `NotImplementedError("GDN does not support packed sequence for now.")`.

Packing requires users to opt in via `pack_sequences_in_batch=True`. This PR is a no-op when that flag is left at its default `False`.
Usage
To exercise the packing path on HF VLM datasets, the user must also set `dataset.skip_getting_attention_mask_from_dataset=False` so the real attention mask reaches `forward_step`. The recipe-side default for that flag will be addressed separately. A fallback path that infers real lengths from the trailing zero-pad is provided for completeness, but it only matches in the (uncommon) case where the collator's `tokenizer.pad_token_id` is 0.
Test plan
* `dataset.pack_sequences_in_batch=True` + `dataset.skip_getting_attention_mask_from_dataset=False` (requires "[main] qwen-vl THD packed-sequence support and fixes" #3323 merged for full correctness).
🤖 Generated with Claude Code