Skip to content

Commit 74093e3

Browse files
authored
Merge pull request #37 from HollowMan6/kwargs
Vision models pass kwargs to self.language_model
2 parents b614d7a + 41f86ba commit 74093e3

File tree

5 files changed: +10 additions, -0 deletions

mbridge/models/gemma3/model.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -260,6 +260,7 @@ def forward(
260260
packed_seq_params: Optional[PackedSeqParams] = None,
261261
*,
262262
inference_params: Optional[BaseInferenceContext] = None,
263+
**kwargs,
263264
) -> torch.Tensor:
264265
"""Forward function of the LLaVA model.
265266
@@ -371,6 +372,7 @@ def forward(
371372
inference_context=inference_context,
372373
runtime_gather_output=runtime_gather_output,
373374
packed_seq_params=packed_seq_params,
375+
**kwargs,
374376
)
375377

376378
return output

mbridge/models/glm4_vl/model.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -135,6 +135,7 @@ def forward(
135135
image_grid_thw: torch.Tensor = None,
136136
video_grid_thw: torch.Tensor = None,
137137
runtime_gather_output=False,
138+
**kwargs,
138139
) -> torch.Tensor:
139140
"""Forward function of the Qwen2VL model.
140141
@@ -206,6 +207,7 @@ def forward(
206207
inference_context=inference_context,
207208
runtime_gather_output=runtime_gather_output,
208209
**(extra_block_kwargs or {}),
210+
**kwargs,
209211
)
210212

211213
return output

mbridge/models/internvl3/model.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -173,6 +173,7 @@ def forward(
173173
inference_params: InferenceParams = None,
174174
packed_seq_params: PackedSeqParams = None,
175175
image_token_index: int = -1,
176+
**kwargs,
176177
) -> torch.Tensor:
177178
use_inference_kv_cache = (
178179
inference_params is not None
@@ -223,6 +224,7 @@ def forward(
223224
labels=labels,
224225
inference_params=inference_params,
225226
packed_seq_params=packed_seq_params,
227+
**kwargs,
226228
)
227229

228230
return output

mbridge/models/qwen2_5_vl/model.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -209,6 +209,7 @@ def forward(
209209
pixel_values_videos: torch.Tensor = None,
210210
image_grid_thw: torch.Tensor = None,
211211
video_grid_thw: torch.Tensor = None,
212+
**kwargs,
212213
) -> torch.Tensor:
213214
"""Forward function of the Qwen2VL model.
214215
@@ -353,6 +354,7 @@ def forward(
353354
# inference_params=inference_params, # currently always None
354355
packed_seq_params=packed_seq_params, # currently always None
355356
**(extra_block_kwargs or {}),
357+
**kwargs,
356358
)
357359

358360
return output

mbridge/models/qwen3_vl/model.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -206,6 +206,7 @@ def forward(
206206
video_grid_thw: torch.Tensor = None,
207207
# cat set at dataset
208208
image_input_mask: torch.Tensor = None,
209+
**kwargs,
209210
) -> torch.Tensor:
210211
"""Forward function of the Qwen3VL model.
211212
@@ -328,6 +329,7 @@ def forward(
328329
visual_pos_masks=visual_pos_masks,
329330
deepstack_visual_embeds=deepstack_visual_embeds,
330331
**(extra_block_kwargs or {}),
332+
**kwargs,
331333
)
332334

333335
return output

0 commit comments

Comments
 (0)