[megatron] fix: VLMs using fused kernels
Currently we get an error about an unexpected keyword argument
`visual_pos_masks`. This is because mbridge also customizes the
`GPTModel` forward for Qwen3VL to support deepstack:

https://github.com/ISEEKYAN/mbridge/blob/ecbdfbdfdc8027004702149d6dc87fbad7417708/mbridge/models/qwen3_vl/gpt_model.py#L84
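
A rough, self-contained reproduction of the failure mode (the class and method
names here are hypothetical stand-ins; only the offending keyword and the
resulting `TypeError` match the real traceback below):

```python
class PatchedLanguageModel:
    # Stand-in for GPTModel with the fused-kernels forward monkey-patched in:
    # the patched forward has a fixed parameter list and no **kwargs.
    def forward(self, input_ids, position_ids=None, attention_mask=None):
        return input_ids


class Qwen3VLLike:
    # Stand-in for mbridge's Qwen3VL wrapper, which forwards deepstack
    # arguments such as `visual_pos_masks` to its language model.
    def __init__(self):
        self.language_model = PatchedLanguageModel()

    def forward(self, input_ids, visual_pos_masks=None):
        return self.language_model.forward(
            input_ids=input_ids,
            visual_pos_masks=visual_pos_masks,  # unexpected keyword
        )


Qwen3VLLike().forward(input_ids=[1, 2, 3], visual_pos_masks=[True])
# TypeError: ... got an unexpected keyword argument 'visual_pos_masks'
```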

Since mcore v0.13.0 introduced `_postprocess` and `_preprocess`, and our patch
focuses on `_postprocess`, I also cleaned up that function for better
maintainability and to fix this extra deepstack argument issue. We can't
simply patch `_postprocess` alone, because we also need to pass the
`temperature` argument through:

```logs
output = self.forward_backward_batch(
/verl_megatron/verl/workers/actor/megatron_actor.py", line 598, in forward_backward_batch
    losses_reduced = forward_backward_func(
miniconda/envs/qwenvl/lib/python3.10/site-packages/megatron/core/pipeline_parallel/schedules.py", line 500, in forward_backward_no_pipelining
    output_tensor, num_tokens = forward_step(
.......
verl_megatron/verl/models/mcore/model_forward_fused.py", line 136, in fused_forward_qwen2_5_vl
    output_orig: CausalLMOutputForPPO = model(
......
mbridge-main/mbridge/models/qwen3_vl/model.py", line 323, in forward
    output = self.language_model(
envs/qwenvl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/envs/qwenvl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: _fused_GPTModel_forward() got an unexpected keyword argument 'visual_pos_masks'
```

In addition, there is a shape mismatch error when calculating `mrope` if we
pass `position_ids` in `fused_forward_qwen2_5_vl`. I tried to debug it, but
the shape passed there doesn't make sense. According to
https://github.com/volcengine/verl/blob/981d781db932ff53a0c584fd501dcd73ce2a8077/verl/models/mcore/model_forward.py#L117
the model will calculate `position_ids` itself, so I followed that code and
stopped passing them. This works for both Qwen2.5VL and Qwen3VL without
throwing further errors.
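
For reference, a minimal sketch of what the packed-sequence forward call looks
like after this change (the helper name is assumed; compare the
`model_forward.py` hunk below):

```python
def call_vlm_forward(model, input_ids_rmpad, packed_seq_params, pixel_values, image_grid_thw):
    # Vision-language models compute their own (mrope) position ids, so we
    # deliberately pass position_ids=None here.
    return model(
        input_ids=input_ids_rmpad,
        attention_mask=None,            # sequences are packed, no mask needed
        position_ids=None,              # model will calculate position_ids
        packed_seq_params=packed_seq_params,
        pixel_values=pixel_values,
        image_grid_thw=image_grid_thw,
    )
```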

I found another issue in the original upstream codebase: the `temperature`
parameter is not correctly passed to `_fused_GPTModel_forward`. To make this
work, I added `**kwargs` so that the models accept additional arbitrary kwargs
and forward them all to `self.language_model`, both on the verl side
(`Qwen2_5VLModel`) and for all the vision models under mbridge.
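
A minimal sketch of the forwarding pattern, assuming a simplified wrapper (the
real classes are `Qwen2_5VLModel` in verl and the vision models in mbridge;
everything else here is illustrative):

```python
class VisionLanguageWrapper:
    """Illustrative wrapper that owns a language model plus a vision encoder."""

    def __init__(self, language_model):
        self.language_model = language_model

    def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs):
        # Extra arguments such as `temperature` (which the fused forward needs)
        # are no longer dropped here; they are passed straight through to the
        # language model.
        return self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )
```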

The functions under `verl/models/mcore/model_forward.py` and
`verl/models/mcore/model_forward_fused.py` are currently very hard to debug
and maintain: there is code duplication, several condition branches are unused
and untested, and they accept arbitrary kwargs without passing them anywhere.
So I also refactored and cleaned up the code in those two files to use a
closure that unifies vision models and plain "GPT" (language) models, for
better maintainability.
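
A simplified sketch of the closure pattern used in the refactor (the real
generator in `model_forward.py` additionally handles sequence packing, the
logits processor, and value-model output; see the diff below):

```python
def model_forward_gen(vision_model: bool = False):
    # Build a forward function specialised for either vision-language models
    # or plain GPT (language-only) models, so both share one code path.
    def model_forward(model, input_ids, attention_mask, position_ids, multi_modal_inputs, **kwargs):
        model_kwargs = {}
        if "pixel_values" in multi_modal_inputs:
            model_kwargs["pixel_values"] = multi_modal_inputs["pixel_values"].to(input_ids.device)
        if "image_grid_thw" in multi_modal_inputs:
            model_kwargs["image_grid_thw"] = multi_modal_inputs["image_grid_thw"].to(input_ids.device)
        return model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # Vision models calculate their own position ids.
            position_ids=position_ids if not vision_model else None,
            **model_kwargs,
        )

    return model_forward


# Hypothetical usage: call sites obtain the variant they need once.
gptmodel_forward = model_forward_gen(vision_model=False)
gptmodel_forward_qwen2_5_vl = model_forward_gen(vision_model=True)
```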

Signed-off-by: Hollow Man <hollowman@opensuse.org>
HollowMan6 committed Oct 23, 2025
commit de86db90a8dfb709efa1df34ebef0f60c018fb84
2 changes: 1 addition & 1 deletion .github/workflows/.deprecate/e2e_prime.yml
@@ -47,7 +47,7 @@ jobs:
HF_ENDPOINT: "https://hf-mirror.com"
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
container:
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
image: "verl-ci-cn-beijing.cr.volces.com/verlai/verl:app-verl0.6-transformers4.56.1-sglang0.5.2-mcore0.13.0-te2.2"
options: --gpus all --shm-size=10g
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
2 changes: 1 addition & 1 deletion .github/workflows/README.md
@@ -31,7 +31,7 @@ permissions:
contents: read

env:
IMAGE: "your vemlp image" # e.g. "verl-ci-cn-beijing.cr.volces.com/verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2"
IMAGE: "your vemlp image" # e.g. "verl-ci-cn-beijing.cr.volces.com/verlai/verl:app-verl0.6-transformers4.56.1-sglang0.5.2-mcore0.13.0-te2.2"
DYNAMIC_RUNNER_URL: "https://sd10g3clalm04ug7alq90.apigateway-cn-beijing.volceapi.com/runner" # public veFaas api

jobs:
2 changes: 1 addition & 1 deletion docker/Dockerfile.extention.awsefa
@@ -1,6 +1,6 @@
# Base Image support aws EFA
# Build Image with frameworks based on this
FROM verlai/verl:app-verl0.5-sglang0.4.6.post5-mcore0.12.2
FROM verlai/verl:app-verl0.6-transformers4.56.1-sglang0.5.2-mcore0.13.0-te2.2

# For aws instances with EFA net interface (Sagemaker AI Pod)
# install EFA driver:
2 changes: 1 addition & 1 deletion docs/start/multinode.rst
@@ -334,7 +334,7 @@ Once the fleet is created, define a Ray cluster task, e.g. in ``ray-cluster.dsta
- PYTHONUNBUFFERED=1
- CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.2
image: verlai/verl:app-verl0.6-transformers4.56.1-sglang0.5.2-mcore0.13.0-te2.2
commands:
- git clone https://github.com/volcengine/verl
- cd verl
192 changes: 59 additions & 133 deletions verl/models/mcore/model_forward.py
@@ -21,105 +21,42 @@
postprocess_packed_seqs_no_padding,
preprocess_packed_seqs,
preprocess_packed_seqs_no_padding,
recover_left_padding,
remove_left_padding,
)


def gptmodel_forward(
model,
input_ids,
attention_mask,
position_ids,
sequence_parallel,
value_model=False,
pack_seqs=True,
logits_processor=None,
logits_processor_args: dict = None,
**kwargs,
):
"""Default forward pass for GPT models with optional sequence packing."""
pre_process = unwrap_model(model).pre_process
post_process = unwrap_model(model).post_process
if pack_seqs:
batch_size, seq_len = attention_mask.shape[:2]
input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process)
input_ids_rmpad = input_ids_rmpad.contiguous()
output_orig = model(
input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids,
packed_seq_params=packed_seq_params,
)
if post_process and logits_processor is not None:
args = {
k: preprocess_packed_seqs(v, attention_mask, pre_process=True)[0]
for k, v in logits_processor_args.items()
}
output_dict = logits_processor(output_orig, **args)
output = {
k: postprocess_packed_seqs(
v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process
)
for k, v in output_dict.items()
}
else:
output = postprocess_packed_seqs(
output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process
)
else:
batch_size, sequence_length = attention_mask.shape
new_input_ids, new_attention_mask, new_position_ids = remove_left_padding(
input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process
)
output = model(input_ids=new_input_ids, attention_mask=new_attention_mask, position_ids=new_position_ids)
if post_process:
output = logits_processor(output, **logits_processor_args)
output = recover_left_padding(
output, new_attention_mask, attention_mask, sequence_length, post_process=post_process
)
if value_model and post_process:
output = output[..., 0]
return output
def model_forward_gen(vision_model: bool = False):
def model_forward(
model,
input_ids,
attention_mask,
position_ids,
multi_modal_inputs: dict,
logits_processor=None,
logits_processor_args: dict = None,
value_model=False,
):
"""Forward pass for models with sequence packing."""
pre_process = (
unwrap_model(model).pre_process if not vision_model else True
) # vision model always needs pre_process
post_process = unwrap_model(model).post_process

model_kwargs = {}
if "pixel_values" in multi_modal_inputs:
model_kwargs["pixel_values"] = multi_modal_inputs["pixel_values"].to(input_ids.device)
if "image_grid_thw" in multi_modal_inputs:
model_kwargs["image_grid_thw"] = multi_modal_inputs["image_grid_thw"].to(input_ids.device)

def gptmodel_forward_qwen2_5_vl(
model,
input_ids,
attention_mask,
position_ids,
sequence_parallel,
value_model=False,
pack_seqs=True,
multi_modal_inputs=None,
logits_processor=None,
logits_processor_args: dict = None,
**kwargs,
):
from megatron.core import parallel_state as mpu

assert mpu.get_context_parallel_world_size() == 1, "qwen2_5_vl's context parallel is not accurate yet"
pre_process = unwrap_model(model).pre_process
post_process = unwrap_model(model).post_process
pixel_values = (
multi_modal_inputs["pixel_values"].to(input_ids.device) if "pixel_values" in multi_modal_inputs else None
)
image_grid_thw = (
multi_modal_inputs["image_grid_thw"].to(input_ids.device) if "image_grid_thw" in multi_modal_inputs else None
)
if pack_seqs:
batch_size, seq_len = attention_mask.shape[:2]
input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True)
input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process)
input_ids_rmpad = input_ids_rmpad.contiguous()
output_orig = model(
input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=None, # model will calculate position_ids
position_ids=position_ids if not vision_model else None, # vision models will calculate position_ids
packed_seq_params=packed_seq_params,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
**model_kwargs,
)

if post_process and logits_processor is not None:
args = {
k: preprocess_packed_seqs(v, attention_mask, pre_process=True)[0]
@@ -136,66 +136,73 @@ def gptmodel_forward_qwen2_5_vl
output = postprocess_packed_seqs(
output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process
)
else:
batch_size, sequence_length = attention_mask.shape
new_input_ids, new_attention_mask, new_position_ids = remove_left_padding(
input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process
)
output = model(
input_ids=new_input_ids,
position_ids=new_position_ids,
attention_mask=new_attention_mask,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
)
output = recover_left_padding(
output, new_attention_mask, attention_mask, sequence_length, post_process=post_process
)
if value_model and post_process:
output = output[..., 0]
return output
if value_model and post_process:
output = output[..., 0]
return output

return model_forward


def gptmodel_forward_no_padding(
model,
input_ids,
value_model=False,
pack_seqs=True,
multi_modal_inputs: dict,
logits_processor=None,
logits_processor_args: dict = None,
**kwargs,
value_model=False,
):
"""Default forward pass for GPT models with optional sequence packing."""
pre_process = unwrap_model(model).pre_process
post_process = unwrap_model(model).post_process
if pack_seqs:
batch_size = input_ids.shape[0]
input_ids_rmpad, packed_seq_params = preprocess_packed_seqs_no_padding(input_ids, pre_process=pre_process)
input_ids_rmpad = input_ids_rmpad.contiguous()
output_orig = model(
input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=None,
packed_seq_params=packed_seq_params,
)

if post_process and logits_processor is not None:
args = {
k: preprocess_packed_seqs_no_padding(v, pre_process=True)[0] for k, v in logits_processor_args.items()
}
output_dict = logits_processor(output_orig, **args)
output = {
k: postprocess_packed_seqs_no_padding(
v, packed_seq_params, input_ids, batch_size, post_process=post_process
)
for k, v in output_dict.items()
}
else:
output = postprocess_packed_seqs_no_padding(
output_orig, packed_seq_params, input_ids, batch_size, post_process=post_process
model_kwargs = {}
if "pixel_values" in multi_modal_inputs:
model_kwargs["pixel_values"] = multi_modal_inputs["pixel_values"].to(input_ids.device)
if "image_grid_thw" in multi_modal_inputs:
model_kwargs["image_grid_thw"] = multi_modal_inputs["image_grid_thw"].to(input_ids.device)

batch_size = input_ids.shape[0]
input_ids_rmpad, packed_seq_params = preprocess_packed_seqs_no_padding(input_ids, pre_process=pre_process)
input_ids_rmpad = input_ids_rmpad.contiguous()
output_orig = model(
input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=None,
packed_seq_params=packed_seq_params,
**model_kwargs,
)

if post_process and logits_processor is not None:
args = {k: preprocess_packed_seqs_no_padding(v, pre_process=True)[0] for k, v in logits_processor_args.items()}
output_dict = logits_processor(output_orig, **args)
output = {
k: postprocess_packed_seqs_no_padding(
v, packed_seq_params, input_ids, batch_size, post_process=post_process
)
for k, v in output_dict.items()
}
else:
raise NotImplementedError("gptmodel_forward_no_padding only supports packed sequences")
output = postprocess_packed_seqs_no_padding(
output_orig, packed_seq_params, input_ids, batch_size, post_process=post_process
)

if value_model and post_process:
# output = output[..., 0]