[model] fix: Handle Qwen VL MTP with context parallelism#3895

Open: cuichenx wants to merge 2 commits into main from kai/issue-3881-mtp-cp-qwen-vl

@cuichenx cuichenx commented May 20, 2026

Summary

  • CP-localize Qwen VL MTP input_ids and position_ids before MCore postprocess when context parallelism is active.
  • Keep the existing sequence-parallel embedding scatter wrapper for MTP and guarantee cleanup with try/finally.
  • Add a CPU-only unit test that asserts MTP postprocess receives the CP-local zigzag token and position-id slices.
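
For readers unfamiliar with the zigzag layout, here is a minimal sketch of the slicing the summary refers to. It assumes the common load-balanced scheme in which the sequence is cut into 2 * cp_size equal chunks and rank r keeps chunks r and 2 * cp_size - 1 - r. The helper name `zigzag_cp_slice` is hypothetical and plain Python lists stand in for tensors, so this is an illustration rather than the code in text_model.py.

```python
def zigzag_cp_slice(tokens, cp_size, cp_rank):
    """Return the CP-local part of `tokens` under a zigzag layout (illustrative only)."""
    num_chunks = 2 * cp_size
    if len(tokens) % num_chunks != 0:
        raise ValueError("sequence length must be divisible by 2 * cp_size")
    chunk = len(tokens) // num_chunks
    front = cp_rank                    # chunk taken from the start of the sequence
    back = num_chunks - 1 - cp_rank    # mirrored chunk taken from the end
    return tokens[front * chunk:(front + 1) * chunk] + tokens[back * chunk:(back + 1) * chunk]


input_ids = list(range(16))
local_rank0 = zigzag_cp_slice(input_ids, cp_size=2, cp_rank=0)
local_rank1 = zigzag_cp_slice(input_ids, cp_size=2, cp_rank=1)
```

With 16 tokens and cp_size=2, rank 1 keeps positions 4 to 11, which is consistent with the [..., 4:8] and [..., 8:12] slices asserted in this PR's unit test.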

Fixes #3881.

Validation

  • Reproduced pre-fix on DFW interactive allocation 11920458, log /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_llm/users/chcui/logs/issue3881_mtp_cp_qwen_vl/repro_pre_fix_20260519_161753.log: ranks 0-3 failed in MultiTokenPredictionLayer._concat_embeddings with Expected size 128 but got size 64.
  • Validated post-fix on DFW interactive allocation 11920502, log /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_llm/users/chcui/logs/issue3881_mtp_cp_qwen_vl/repro_post_fix_20260519_162314.log: completed 0:0; iteration 1/1, lm loss: 1.354103E+01, mtp_1 loss: 6.454011E+00, grad norm: 7.894, number of nan iterations: 0.
  • git diff --check
  • python3 -m py_compile src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_text_model_forward.py
  • uv tool run ruff check tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_text_model_forward.py src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py
  • uv tool run ruff format --check tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_text_model_forward.py src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py
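
The pre-fix error is consistent with a full-length tensor meeting a CP-local one. The sketch below is an illustration under stated assumptions, not MCore code: it assumes a sequence length of 128 and a context-parallel size of 2 (the log only reports the sizes 128 and 64), and `check_concat_lengths` is a hypothetical helper that mimics the length check performed when two tensors are concatenated along another dimension.

```python
def local_seq_len(full_seq_len, cp_size):
    """Sequence length each rank holds once the sequence is split across CP ranks."""
    if full_seq_len % cp_size != 0:
        raise ValueError("sequence length must be divisible by cp_size")
    return full_seq_len // cp_size


def check_concat_lengths(embedding_len, hidden_len):
    """Raise if two tensors that must share a sequence length do not (hypothetical check)."""
    if embedding_len != hidden_len:
        raise ValueError(f"Expected size {embedding_len} but got size {hidden_len}")


FULL_SEQ_LEN = 128  # assumed from the error message
CP_SIZE = 2         # assumed: 128 / 64

# Assumption: decoder hidden states are already CP-local.
hidden_len = local_seq_len(FULL_SEQ_LEN, CP_SIZE)

# Pre-fix: embeddings are built from the full-length input_ids.
try:
    check_concat_lengths(FULL_SEQ_LEN, hidden_len)
    pre_fix_error = None
except ValueError as exc:
    pre_fix_error = str(exc)

# Post-fix: input_ids are CP-localized first, so the lengths agree and nothing is raised.
check_concat_lengths(local_seq_len(FULL_SEQ_LEN, CP_SIZE), hidden_len)
```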

Notes

  • uv run pre-commit run --all-files and uv run --group test python -m pytest tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_text_model_forward.py -q could not start on the local host because the locked nvidia-resiliency-ext==0.6.0 wheel requires manylinux_2_39, while this host resolves as manylinux_2_31_x86_64.
  • System Python on this host does not have pytest installed, so the targeted pytest was not run outside uv either.
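
The wheel mismatch in the first note comes down to the glibc version: under PEP 600, a manylinux_X_Y wheel needs glibc X.Y or newer. The following sketch shows one way to check a host ahead of time. `supports_manylinux` and `host_glibc_version` are hypothetical helpers, the architecture suffix is ignored, and `platform.libc_ver()` may return empty strings on non-glibc systems.

```python
import platform
import re


def supports_manylinux(tag, glibc_version):
    """True if a host with `glibc_version` (major, minor) can use a manylinux_X_Y wheel."""
    match = re.fullmatch(r"manylinux_(\d+)_(\d+)(?:_\w+)?", tag)
    if match is None:
        raise ValueError(f"not a PEP 600 manylinux tag: {tag}")
    required = (int(match.group(1)), int(match.group(2)))
    return glibc_version >= required


def host_glibc_version():
    """Best-effort (major, minor) glibc version of this host, or None if unknown."""
    lib, version = platform.libc_ver()
    parts = version.split(".")
    if lib != "glibc" or len(parts) < 2:
        return None
    return int(parts[0]), int(parts[1])


# The situation described above: the wheel needs glibc 2.39, the host offers 2.31.
can_install = supports_manylinux("manylinux_2_39", (2, 31))
```

Real installers perform a fuller compatibility check than this; the sketch only shows why a 2.31 host rejects a 2.39 wheel.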

Signed-off-by: Chen Cui <chcui@nvidia.com>

copy-pr-bot Bot commented May 20, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@cuichenx cuichenx added the bug (Something isn't working) and area:model (Model implementations and HF bridge logic) labels May 20, 2026
Signed-off-by: Chen Cui <chcui@nvidia.com>
@cuichenx cuichenx marked this pull request as ready for review May 20, 2026 00:12
@cuichenx (Contributor, Author) commented

/ok to test e4dce6f

expected_position_ids = torch.cat([position_ids[..., 4:8], position_ids[..., 8:12]], dim=-1)
assert output == "ok"
assert torch.equal(dummy.postprocess_args["input_ids"], expected_input_ids)
assert torch.equal(dummy.postprocess_args["position_ids"], expected_position_ids)

Nit: consider adding a small test for the position_ids=None edge case with CP > 1 and MTP enabled. The production code guards this explicitly (line 193: if postprocess_position_ids is not None), but there's no test exercising that branch — a regression there would pass silently.

Something like:

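import pytest
import torch

# Qwen3VLGPTModel and the _DummyModel helper are assumed to be available in
# test_text_model_forward.py already, as in the existing tests.


@pytest.mark.unit
def test_mtp_postprocess_with_none_position_ids():
    # CP > 1 with MTP enabled, so the CP-localization path is taken.
    dummy = _DummyModel(mtp_process=True, cp_size=2, cp_rank=0)
    input_ids = torch.arange(16, dtype=torch.long).view(1, 16)
    attention_mask = torch.ones((1, 16), dtype=torch.long)

    output = Qwen3VLGPTModel.forward(
        dummy,
        input_ids=input_ids,
        position_ids=None,
        attention_mask=attention_mask,
    )

    assert output == "ok"
    # A missing position_ids must be passed through untouched, not sliced.
    assert dummy.postprocess_args["position_ids"] is None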
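
The branch this test targets can be pictured with a small stand-alone sketch. It is an assumption about the shape of the logic, not the code in text_model.py: `localize_for_mtp` and `zigzag` are hypothetical names, lists stand in for tensors, and the assumed zigzag layout gives rank r chunks r and 2 * cp_size - 1 - r.

```python
def zigzag(seq, cp_size, cp_rank):
    """Keep chunks cp_rank and 2*cp_size-1-cp_rank of a sequence split into 2*cp_size chunks."""
    chunk = len(seq) // (2 * cp_size)
    back = 2 * cp_size - 1 - cp_rank
    return seq[cp_rank * chunk:(cp_rank + 1) * chunk] + seq[back * chunk:(back + 1) * chunk]


def localize_for_mtp(input_ids, position_ids, cp_size, cp_rank):
    """Slice MTP inputs to the local CP shard, passing a missing position_ids through."""
    if cp_size <= 1:
        return input_ids, position_ids
    local_ids = zigzag(input_ids, cp_size, cp_rank)
    # The guard under discussion: only slice position_ids when they were provided.
    local_pos = zigzag(position_ids, cp_size, cp_rank) if position_ids is not None else None
    return local_ids, local_pos


ids, pos = localize_for_mtp(list(range(16)), None, cp_size=2, cp_rank=0)
```

Without the guard, the `None` case would fail inside the slicing helper instead of reaching postprocess unchanged.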

claude Bot commented May 20, 2026

Light Review

Clean fix. The CP-localization of input_ids and position_ids before MTP postprocess mirrors the zigzag pattern already used in model.py for combined embeddings (lines 504-505), and the guard conditions (mtp_process and cp_size > 1 and packed_seq_params is None) are consistent. The try/finally for shadow embedding cleanup is a good hardening improvement.
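
The try/finally hardening can be illustrated generically. The sketch below is not the PR's code: `Embedding`, `scatter_wrapper`, and `run_postprocess` are hypothetical stand-ins that show how a temporarily wrapped method is restored even when the wrapped call raises.

```python
class Embedding:
    """Toy stand-in for an embedding module."""

    def forward(self, tokens):
        return [t * 10 for t in tokens]


def run_postprocess(embedding, tokens, fail=False):
    """Temporarily wrap embedding.forward, then restore it no matter what happens."""
    original_forward = embedding.forward

    def scatter_wrapper(toks):
        # Hypothetical extra step: keep only the first half of the output.
        out = original_forward(toks)
        return out[: len(out) // 2]

    embedding.forward = scatter_wrapper
    try:
        if fail:
            raise RuntimeError("postprocess failed")
        return embedding.forward(tokens)
    finally:
        # Runs on success and on error, so the wrapper never leaks.
        embedding.forward = original_forward


emb = Embedding()
result = run_postprocess(emb, [1, 2, 3, 4])
```

Without the `finally` clause, an exception inside the wrapped call would leave the wrapper installed for every later forward pass.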

One minor gap: the position_ids is None branch (text_model.py:193) is untested. See inline comment for a suggested test.

Suggested test cases:

  • No perf tests impacted.

@cuichenx cuichenx added the needs-more-tests label (Requires additional L0 and L1 test coverage before merge) May 20, 2026
@yaoyu-33 yaoyu-33 added the needs-review label (PR is ready for code review and waiting on a reviewer) May 20, 2026

Labels

  • area:model: Model implementations and HF bridge logic
  • bug: Something isn't working
  • needs-more-tests: Requires additional L0 and L1 test coverage before merge
  • needs-review: PR is ready for code review and waiting on a reviewer

Development

Successfully merging this pull request may close these issues.

[bug] MTP shadow embedding missing CP scatter causes shape mismatch when context_parallel_size > 1

2 participants