6 changes: 3 additions & 3 deletions docs/source/models/supported-models.md
@@ -35,7 +35,7 @@ Note: Support for other models may vary. Features marked "N/A" are not applicable
| DeepseekV3ForCausalLM | Yes | Yes | Yes | Yes | Yes [^1] | Yes | No | No | Yes | Yes | Yes [^2] | N/A | Yes | Yes |
| Qwen3MoeForCausalLM | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | N/A | Yes | Yes |
| Llama4ForConditionalGeneration | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes | Yes | Yes | Untested | N/A | Yes | Yes |
| GPT-OSS | Yes | Yes | Yes | Yes | No | No | Yes | No | Yes | Yes | No | N/A | Yes | Yes |
| GPT-OSS | Yes | Yes | Yes | Yes | No | No | Yes | No | Yes | Yes | No | N/A | Yes | Yes |

[^1]: Chunked Prefill for MLA can only be enabled on SM100.
[^2]: KV cache reuse for MLA can only be enabled on SM90/SM100 and in BF16/FP8 KV cache dtype.
@@ -45,12 +45,12 @@ Note: Support for other models may vary. Features marked "N/A" are not applicable

| Model Architecture/Feature | Overlap Scheduler | CUDA Graph | Chunked Prefill | Torch Sampler | TLLM C++ Sampler | KV Cache Reuse | Logits Post Processor | EPD Disaggregated Serving | Modality |
| ---------------------------------- | ----------------- | ---------- | --------------- | ------------- | ---------------- | -------------- | --------------------- | ------------------------- | -------- |
| Gemma3ForConditionalGeneration | Yes | Yes | N/A | Yes | Yes | N/A | Yes | No | L + I |
| Gemma3ForConditionalGeneration | Yes | Yes | N/A | Yes | Yes | N/A | Yes | No | L + I |
| HCXVisionForCausalLM | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
| LlavaLlamaModel (VILA) | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I + V |
| LlavaNextForConditionalGeneration | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
| Llama4ForConditionalGeneration | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
| Mistral3ForConditionalGeneration | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
| Mistral3ForConditionalGeneration | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | L + I |
| Phi4MMForCausalLM | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I + A |
| Qwen2VLForConditionalGeneration | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + V |
| Qwen2_5_VLForConditionalGeneration | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + V |
90 changes: 65 additions & 25 deletions tensorrt_llm/_torch/models/modeling_mistral.py
@@ -14,8 +14,8 @@
PositionalEmbeddingParams, RopeParams)
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models import modeling_pixtral
from tensorrt_llm._torch.models.modeling_multimodal_utils import \
fuse_input_embeds
from tensorrt_llm._torch.models.modeling_multimodal_utils import (
find_input_mm_embeds, fuse_input_embeds, get_multimodal_embeddings)
from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
DecoderModelForCausalLM,
_load_weights_impl,
@@ -29,10 +29,12 @@
from tensorrt_llm._torch.speculative import SpecMetadata
from tensorrt_llm._utils import nvtx_range
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.inputs import (ExtraProcessedInputs, InputProcessor,
from tensorrt_llm.inputs import (BaseMultimodalInputProcessor,
ExtraProcessedInputs, InputProcessor,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from tensorrt_llm.inputs.multimodal import MultimodalParams
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.logger import logger

@@ -212,7 +214,7 @@ def __init__(
)


class Mistral3InputProcessor(InputProcessor):
class Mistral3InputProcessor(BaseMultimodalInputProcessor, InputProcessor):

def __init__(
self,
@@ -276,6 +278,31 @@ def __call__(

return input_ids, extra_processed_inputs

def get_vocab_size(self) -> int:
"""Return the vocab size of the model."""
# Unlike some other VLMs, mistral3's vocab size is stored in its `text_config`, not the top-level
# config.
return self.model_config.text_config.vocab_size

def get_mm_token_ids(self) -> torch.Tensor:
"""Get the IDs of all multimodal tokens (placeholders and special tokens alike)."""
return torch.tensor([
# This is the `[IMG]` token id inserted into the prompt that should be replaced with image
# embeddings.
self._processor.image_token_id,
# This is the `[IMG_BREAK]` token id at the end of every "row".
self._processor.image_break_token_id,
# This is the `[IMG_END]` token id to signify the end of an image.
self._processor.image_end_token_id,
])

def get_mm_special_token_ids(self) -> torch.Tensor:
"""Get the IDs of special multimodal tokens (placeholders not included)."""
return torch.tensor([
self._processor.image_break_token_id,
self._processor.image_end_token_id,
])


@register_auto_model("Mistral3ForConditionalGeneration")
@register_input_processor(
@@ -380,27 +407,12 @@ def forward(
mm_embeds = []
multimodal_params_len = len(multimodal_params)
if multimodal_params_len > 0:
pixel_values = [
x.multimodal_data["image"]["pixel_values"]
for x in multimodal_params
]
image_sizes = [
x.multimodal_data["image"]["image_sizes"]
for x in multimodal_params
]
if not (len(pixel_values) == len(image_sizes) ==
multimodal_params_len):
raise ValueError(
f"Expected as many `pixel_values` ({len(pixel_values)}) and "
f"`image_sizes` ({len(image_sizes)}) as number of multimodal parameters "
f"({multimodal_params_len}).")
image_sizes = [torch.tensor(x) for x in image_sizes]
batched_pixel_values, batched_image_sizes = self.batch_pixel_values(
pixel_values=pixel_values, image_sizes=image_sizes)
mm_embeds = [
self._get_image_features(pixel_values=batched_pixel_values,
image_sizes=batched_image_sizes)
]
mm_embeds = get_multimodal_embeddings(
encoder_forward_fn=self._vision_forward,
multimodal_params=multimodal_params[:num_context_requests],
)
mm_embeds = find_input_mm_embeds(
mm_embeds, multimodal_params[:num_context_requests])

with nvtx_range("[mistral] Fuse input embeds"):
input_ids, inputs_embeds = fuse_input_embeds(
@@ -440,6 +452,34 @@ def _get_sub_model_config(

return sub_model_config

# NOTE: this is defined as a separate method with this specific signature in order to be compatible
# with `get_multimodal_embeddings`.
def _vision_forward(
self,
multimodal_params: List[MultimodalParams]) -> List[torch.Tensor]:
multimodal_params_len = len(multimodal_params)
pixel_values = [
x.multimodal_data["image"]["pixel_values"]
for x in multimodal_params
]
image_sizes = [
x.multimodal_data["image"]["image_sizes"] for x in multimodal_params
]
if not (len(pixel_values) == len(image_sizes) == multimodal_params_len):
raise ValueError(
f"Expected as many `pixel_values` ({len(pixel_values)}) and "
f"`image_sizes` ({len(image_sizes)}) as number of multimodal parameters "
f"({multimodal_params_len}).")
image_sizes = [torch.tensor(x) for x in image_sizes]
batched_pixel_values, batched_image_sizes = self.batch_pixel_values(
pixel_values=pixel_values, image_sizes=image_sizes)
mm_embeds = [
self._get_image_features(pixel_values=batched_pixel_values,
image_sizes=batched_image_sizes)
]

return mm_embeds

# Original implementation:
# https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/mistral3/
# modeling_mistral3.py#L341
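Note on the new `get_mm_token_ids` / `get_mm_special_token_ids` hooks above: KV-cache reuse for multimodal prompts needs to know which token IDs stand in for image content, so cache-block hashes can incorporate the image data at placeholder positions while treating special tokens such as `[IMG_BREAK]` and `[IMG_END]` as ordinary tokens. A minimal, self-contained sketch of that idea (hypothetical helper, not TensorRT-LLM's actual hashing code):

```python
import hashlib

import torch


def block_hash(token_ids: torch.Tensor, mm_token_ids: torch.Tensor,
               mm_special_token_ids: torch.Tensor, image_hash: int) -> str:
    """Sketch: hash one block of tokens for KV-cache reuse.

    Placeholder tokens (multimodal but not "special") are substituted with a
    hash of the image content, so two requests only share cache blocks when
    both the text and the images match.
    """
    special = set(mm_special_token_ids.tolist())
    placeholders = set(mm_token_ids.tolist()) - special
    key = [image_hash if tok in placeholders else tok for tok in token_ids.tolist()]
    return hashlib.sha256(str(key).encode()).hexdigest()
```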
80 changes: 54 additions & 26 deletions tests/integration/defs/test_e2e.py
@@ -2561,22 +2561,32 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,


@pytest.mark.parametrize("modality", ["image", "video"])
@pytest.mark.parametrize("model_name,model_path", [
("llava-v1.6-mistral-7b", "llava-v1.6-mistral-7b-hf"),
("qwen2.5-vl-7b-instruct", "Qwen2.5-VL-7B-Instruct"),
])
@pytest.mark.parametrize(
"model_name,model_path,match_ratio",
[
("llava-v1.6-mistral-7b", "llava-v1.6-mistral-7b-hf", 0.8),
("qwen2.5-vl-7b-instruct", "Qwen2.5-VL-7B-Instruct", 0.8),
pytest.param(
"mistral-small-3.1-24b-instruct",
"Mistral-Small-3.1-24B-Instruct-2503",
# Lower threshold to give some wiggle room for flakiness.
0.6,
marks=pytest.mark.skip_less_device_memory(80000)),
])
def test_ptp_quickstart_multimodal_kv_cache_reuse(llm_root, llm_venv,
model_name, model_path,
modality):
modality, match_ratio):
# NOTE: individual tests need to be enabled in
# tests/integration/test_lists/qa/examples_test_list.txt

example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
test_data_root = Path(
os.path.join(llm_models_root(), "multimodals", "test_data"))
print(f"Accuracy test {model_name} {modality} mode with example inputs.")
if modality == "video" and model_name == "llava-v1.6-mistral-7b":
pytest.skip("Skipping video modality test for llava-v1.6-mistral-7b")
if modality == "video" and model_name in {
"llava-v1.6-mistral-7b", "mistral-small-3.1-24b-instruct"
}:
pytest.skip(f"Skipping video modality test for {model_name}")

num_same_requests = 3 # test kv cache reuse with multiple same requests
accuracy_inputs = {
@@ -2612,6 +2622,14 @@ def test_ptp_quickstart_multimodal_kv_cache_reuse(llm_root, llm_venv,
["woman", "neon", "night", "jacket", "wet"],
] * num_same_requests,
},
"mistral-small-3.1-24b-instruct": {
"image": [
[
"cloud", "dramatic", "seascape", "ocean", "turbulent",
"waves"
],
] * num_same_requests,
},
}

cmd = [
@@ -2651,24 +2669,33 @@


@pytest.mark.parametrize("modality", ["image", "video"])
@pytest.mark.parametrize("model_name,model_path", [
("llava-v1.6-mistral-7b", "llava-v1.6-mistral-7b-hf"),
("qwen2.5-vl-7b-instruct", "Qwen2.5-VL-7B-Instruct"),
])
@pytest.mark.parametrize(
"model_name,model_path,match_ratio",
[
("llava-v1.6-mistral-7b", "llava-v1.6-mistral-7b-hf", 0.8),
("qwen2.5-vl-7b-instruct", "Qwen2.5-VL-7B-Instruct", 0.8),
pytest.param(
"mistral-small-3.1-24b-instruct",
"Mistral-Small-3.1-24B-Instruct-2503",
# Lower threshold to give some wiggle room for flakiness.
0.6,
marks=pytest.mark.skip_less_device_memory(80000)),
])
def test_ptp_quickstart_multimodal_chunked_prefill(llm_root, llm_venv,
model_name, model_path,
modality):
modality, match_ratio):
# NOTE: individual tests need to be enabled in
# tests/integration/test_lists/qa/examples_test_list.txt

example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
test_data_root = Path(
os.path.join(llm_models_root(), "multimodals", "test_data"))
print(f"Accuracy test {model_name} {modality} mode with example inputs.")
if modality == "video" and model_name == "llava-v1.6-mistral-7b":
if modality == "video" and model_name in {
"llava-v1.6-mistral-7b", "mistral-small-3.1-24b-instruct"
}:
pytest.skip("Skipping video modality test for llava-v1.6-mistral-7b")

num_same_requests = 3 # test kv cache reuse with multiple same requests
accuracy_inputs = {
"image": {
"prompt": [
@@ -2702,17 +2729,6 @@ def test_ptp_quickstart_multimodal_chunked_prefill(llm_root, llm_venv,
["highway", "vehicles", "traffic", "bus", "suburban"],
],
},
"qwen2-vl-7b-instruct": {
"image": [
["ocean", "waves", "atmosphere", "stormy", "clouds", "intense"],
["trees", "winding", "road", "sunny", "sky", "atmosphere"],
["traffic", "vehicles", "moderate", "lanes", "road", "cars"],
],
"video": [
["city", "night", "lights", "jacket", "wet"],
["earth", "spinning", "black"],
],
},
"qwen2.5-vl-7b-instruct": {
"image": [
["dramatic", "moody", "ocean", "stormy", "sky", "waves"],
@@ -2727,6 +2743,19 @@ def test_ptp_quickstart_multimodal_chunked_prefill(llm_root, llm_venv,
["earth", "world", "night", "lights", "cities"],
],
},
"mistral-small-3.1-24b-instruct": {
"image": [
[
"cloud", "dramatic", "seascape", "ocean", "turbulent",
"waves"
],
["scenic", "rock", "landscape", "monolith", "formation"],
[
"multi-lane", "highway", "moderate", "traffic", "flow",
"vehicles", "congestion"
],
],
},
}

cmd = [
Expand All @@ -2744,7 +2773,6 @@ def test_ptp_quickstart_multimodal_chunked_prefill(llm_root, llm_venv,
]

output = llm_venv.run_cmd(cmd, caller=check_output)
match_ratio = 4.0 / 5
for prompt_output, prompt_keywords in zip(
parse_output(output), expected_keywords[model_name][modality]):
matches = [
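For reference, the keyword check that `match_ratio` feeds into (the loop is truncated above) amounts to asserting that a minimum fraction of the expected keywords appears in each generated answer. A sketch of that logic, with names assumed from the surrounding test code:

```python
def assert_keyword_ratio(prompt_output: str, prompt_keywords: list[str],
                         match_ratio: float) -> None:
    # Count expected keywords that show up in the generated text.
    matches = [kw in prompt_output.lower() for kw in prompt_keywords]
    obs_ratio = sum(matches) / len(matches)
    assert obs_ratio >= match_ratio, (
        f"only {obs_ratio:.2f} of keywords {prompt_keywords} matched "
        f"(threshold {match_ratio})")
```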
12 changes: 6 additions & 6 deletions tests/integration/test_lists/qa/llm_function_core.txt
@@ -646,12 +646,12 @@ test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistr
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True]
test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-False]
test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[llava-v1.6-mistral-7b-llava-v1.6-mistral-7b-hf-image]
test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-image]
test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video]
test_e2e.py::test_ptp_quickstart_multimodal_chunked_prefill[llava-v1.6-mistral-7b-llava-v1.6-mistral-7b-hf-image]
test_e2e.py::test_ptp_quickstart_multimodal_chunked_prefill[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video]
test_e2e.py::test_ptp_quickstart_multimodal_chunked_prefill[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-image]
test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[llava-v1.6-mistral-7b-llava-v1.6-mistral-7b-hf-0.8-image]
test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-0.8-image]
test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-0.8-video]
test_e2e.py::test_ptp_quickstart_multimodal_chunked_prefill[llava-v1.6-mistral-7b-llava-v1.6-mistral-7b-hf-0.8-image]
test_e2e.py::test_ptp_quickstart_multimodal_chunked_prefill[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-0.8-video]
test_e2e.py::test_ptp_quickstart_multimodal_chunked_prefill[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-0.8-image]
test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[audio]
test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image]
test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]
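The renamed entries above follow from pytest's ID generation: every parametrized value, including the new `match_ratio`, becomes a segment of the node ID, with the topmost decorator's value (`modality`) appearing last. A standalone illustration (not the real test) of how such IDs are produced:

```python
import pytest


@pytest.mark.parametrize("modality", ["image", "video"])
@pytest.mark.parametrize("model_name,model_path,match_ratio", [
    ("qwen2.5-vl-7b-instruct", "Qwen2.5-VL-7B-Instruct", 0.8),
])
def test_demo(model_name, model_path, match_ratio, modality):
    # Collected e.g. as:
    # test_demo[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-0.8-image]
    assert 0.0 < match_ratio <= 1.0
```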
2 changes: 2 additions & 0 deletions tests/integration/test_lists/test-db/l0_h100.yml
@@ -240,6 +240,8 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance]
- test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
- test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True]
- test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-0.6-image]
- test_e2e.py::test_ptp_quickstart_multimodal_chunked_prefill[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-0.6-image]
- condition:
ranges:
system_gpu_count: