Skip to content

[data, model, training] fix: Stabilize Qwen3-VL packed SFT#3869

Open
wplf wants to merge 1 commit into
NVIDIA-NeMo:mainfrom
wplf:codex/qwen3vl-thd-moe-fixes
Open

[data, model, training] fix: Stabilize Qwen3-VL packed SFT#3869
wplf wants to merge 1 commit into
NVIDIA-NeMo:mainfrom
wplf:codex/qwen3vl-thd-moe-fixes

Conversation

@wplf
Copy link
Copy Markdown
Contributor

@wplf wplf commented May 18, 2026

What does this PR do ?

Fixes several Qwen3-VL packed SFT issues that can corrupt MRoPE position ids, route padding tokens through MoE experts, or swap dist-train vision/deepstack tensors. These issues are most visible with packed THD training, mixed image/text-only VLM batches, MoE, and dist-train vision-module splitting.

Bugfix Summary

  • Packed THD MRoPE: compute MRoPE from the padded packed sequence layout, temporarily unpack THD input ids to BSHD when needed, then repack position ids back to THD. This avoids position-id/layout mismatch for packed Qwen3-VL forward.
  • MoE padding mask: preserve the dataset 2D padding mask through Qwen3-VL batch packing, model forward, text model, and decoder block calls. MoE now receives the real padding mask first and only falls back to lm_input_ids.eq(0) when no mask exists.
  • Qwen VLM collate mask preservation: when a batch mixes image examples and text-only examples, keep and merge each processor attention_mask alongside input_ids. This lets packed SFT derive padding_mask and moe_padding_mask from real padding metadata instead of token id heuristics.
  • Dist-train vision payload order: pack deepstack feature chunks first and final vision_embeds last, matching set_dist_train_input_tensors, which unpacks chunks[:-1] as deepstack and chunks[-1] as final vision embeddings.
  • Pure-text/deepstack handling: allow deepstack_visual_embeds=None in split_deepstack_embs so pure-text microbatches do not crash.

Validation Findings

THD MRoPE probe, micro-batch size 3

I added a tiny local probe that feeds already-packed THD input_ids with shape [1, 18], micro_batch_size=3, seq_len=6, and cu_seqlens=[0, 6, 12, 18]. The probe stubs the language model and inspects the exact position_ids passed from Qwen3VLModel.forward into the LM.

With this PR:

rope input shape seen by get_rope_index: [3, 6]
position_ids: 0 1 2 3 4 5 | 0 1 2 3 4 5 | 0 1 2 3 4 5

Without this PR, at base 39b79eb78:

rope input shape seen by get_rope_index: [1, 18]
position_ids: 0 1 2 3 4 5 | 6 7 8 9 10 11 | 12 13 14 15 16 17

So the fix makes packed THD MRoPE reset per packed sample using cu_seqlens; without it, the flat THD row receives one continuous RoPE index across the whole microbatch.

Zero-LR MoE BSHD/THD parity check

I also ran a fake-small Qwen3-VL MoE model with padded input, micro_batch_size=2, 4 experts, top-2 routing, and lr=0, comparing BSHD against THD while holding weights fixed.

steps: 20
max_abs_loss_diff: 0.0
max_abs_forward_diff: 0.0
max_active_abs_forward_diff: 0.0
one-step global_max_abs_grad_diff: 2.2351741790771484e-08

The tiny gradient delta is float32 roundoff scale from different layout/reduction order, while forward output and loss are exactly aligned for the same padded tokens and MoE padding mask.

Tests

  • Added focused unit coverage for packed THD MRoPE, explicit packed THD MoE padding-mask propagation, Qwen collate attention-mask preservation, dist-train vision payload order, pure-text deepstack handling, and text-model padding-mask passthrough.

Local checks run:

  • uv run pre-commit run --all-files
  • uv --project /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_mcore/users/jinliangl/repos/Megatron-Bridge run python -m pytest /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_mcore/users/jinliangl/repos/Megatron-Bridge/tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_model.py -k "packed_thd" -q
  • uv --project /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_mcore/users/jinliangl/repos/Megatron-Bridge run python -m pytest /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_mcore/users/jinliangl/repos/Megatron-Bridge/tests/unit_tests/data/vlm_datasets/test_collate.py -k "qwen2_5_collate_fn" -q
  • uv --project /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_mcore/users/jinliangl/repos/Megatron-Bridge run python -m pytest /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_mcore/users/jinliangl/repos/Megatron-Bridge/tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_text_model_forward.py -q
  • uv --project /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_mcore/users/jinliangl/repos/Megatron-Bridge run ruff check /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_mcore/users/jinliangl/repos/Megatron-Bridge/tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_model.py /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_mcore/users/jinliangl/repos/Megatron-Bridge/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py

Note: pytest was run from /tmp because this checkout already has nemo_experiments/ in the repo root and the global test fixture refuses to run when that directory exists.

GitHub Actions CI

External branch from wplf fork. An NVIDIA developer may need to trigger CI with /ok to test 0b9d864f if copy-pr-bot does not start automatically.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation? Not needed; code and tests only.
  • Does the PR affect components that are optional to install? No.

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 18, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@wplf wplf added area:data Dataset builders, preprocessing, and samplers area:model Model implementations and HF bridge logic area:training Training loop, callbacks, and runtime integration bug Something isn't working external model-qwen needs-review PR is ready for code review and waiting on a reviewer t-seqpacking labels May 18, 2026
@wplf wplf force-pushed the codex/qwen3vl-thd-moe-fixes branch 3 times, most recently from bf8360c to 5e1fd46 Compare May 18, 2026 09:02
@wplf wplf changed the title [model, data] fix: Stabilize Qwen3-VL packed SFT [model, training] fix: Stabilize Qwen3-VL packed SFT May 18, 2026
@wplf wplf removed the area:data Dataset builders, preprocessing, and samplers label May 18, 2026
@wplf wplf force-pushed the codex/qwen3vl-thd-moe-fixes branch from 5e1fd46 to b6d035d Compare May 18, 2026 09:15
@wplf wplf changed the title [model, training] fix: Stabilize Qwen3-VL packed SFT [data, model, training] fix: Stabilize Qwen3-VL packed SFT May 18, 2026
@wplf wplf added the area:data Dataset builders, preprocessing, and samplers label May 18, 2026
@wplf wplf force-pushed the codex/qwen3vl-thd-moe-fixes branch 2 times, most recently from 122b751 to 9823be7 Compare May 18, 2026 09:41
Comment thread src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py
@wplf
Copy link
Copy Markdown
Contributor Author

wplf commented May 18, 2026

/ok to test 9823be7

@wplf wplf force-pushed the codex/qwen3vl-thd-moe-fixes branch 2 times, most recently from 256bc87 to a74299e Compare May 18, 2026 16:26
@wplf
Copy link
Copy Markdown
Contributor Author

wplf commented May 18, 2026

/ok to test a74299e

@yaoyu-33
Copy link
Copy Markdown
Contributor

/claude review

Comment thread tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_model.py
@claude
Copy link
Copy Markdown
Contributor

claude Bot commented May 18, 2026

Light Review — PR 3869

Clean, well-motivated bugfix PR. Five distinct issues addressed with matching tests. A few observations:

1. Test gap — explicit moe_padding_mask propagation: The new test_forward_packed_thd_rope_uses_padded_cu_seqlens test covers the packed THD MRoPE path thoroughly, but does not pass an explicit moe_padding_mask. This means only the fallback path (lm_input_ids.eq(0) at model.py:672) is exercised in unit tests. The primary path — where the collate-derived moe_padding_mask is passed through preprocess_packed_seqs in two code branches (model.py lines 536-542 and 598-604) — lacks unit coverage. A companion test that supplies moe_padding_mask and asserts it arrives at language_model.last_kwargs[padding_mask] correctly after THD packing would close this gap.

2. Dist-train pack order fix is well-tested: The test_pack_dist_train_vision_module_output_matches_unpack_contract test clearly validates the deepstack-first, vision-last ordering contract. The existing test_pack_dist_train_vision_module_output assertion update is consistent.

3. Collate attention_mask preservation: The test_qwen2_5_collate_fn_preserves_attention_mask_for_mixed_image_text_batch test covers the mixed image/text merge path. The pad_to refactor to accept a pad_value argument is clean.

4. preproc_output[:6] index coupling (text_model.py:162): This assumes the base class _preprocess returns padding_mask at index 5. The test covers this, but this is a fragile coupling to the upstream Megatron-Core tuple layout — worth a defensive comment if the upstream contract is not stable.

Suggested test cases: No perf tests impacted.

Signed-off-by: jinliangl <jinliangl@nvidia.com>
@wplf wplf force-pushed the codex/qwen3vl-thd-moe-fixes branch from a74299e to 0b9d864 Compare May 19, 2026 03:04
@wplf
Copy link
Copy Markdown
Contributor Author

wplf commented May 19, 2026

/ok to test 0b9d864

@wplf
Copy link
Copy Markdown
Contributor Author

wplf commented May 19, 2026

/claude review

Comment thread src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py
Comment thread src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py
@claude
Copy link
Copy Markdown
Contributor

claude Bot commented May 19, 2026


Review Summary

Solid bugfix PR with good test coverage across all five issues. The fixes are well-scoped and the validation methodology (THD MRoPE probe + zero-LR parity check) is thorough.

Observations

  1. _preprocess index 5 contract (text_model.py:146-162): The code assumes MCore GPTModel._preprocess returns padding_mask at tuple index 5. The unit test mocks _preprocess, so it validates the Qwen3VL reading logic but will not catch an upstream MCore layout change. Consider a guard or pinned-version comment.

  2. lm_input_ids.eq(0) fallback (model.py:671-672): The fallback heuristic for MoE padding mask treats token id 0 as padding. This is fine as a last resort when no explicit mask is available, but worth a brief comment so future readers know it is intentional and approximate.

Both are minor robustness notes, not blockers.

Suggested test cases

No perf tests impacted.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:data Dataset builders, preprocessing, and samplers area:model Model implementations and HF bridge logic area:training Training loop, callbacks, and runtime integration bug Something isn't working model-qwen needs-review PR is ready for code review and waiting on a reviewer t-seqpacking

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants