52 changes: 52 additions & 0 deletions examples/llm-api/out_of_tree_example/readme.md
@@ -0,0 +1,52 @@
# Out-of-tree Model Development
The file `modeling_opt.py` shows how a custom model can be defined with the TRT-LLM APIs without modifying the TRT-LLM source code.

The file `main.py` shows how to run inference for such custom models using the LLM API.
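
At its core, an out-of-tree model only needs to register itself with TRT-LLM's model registry so the LLM API can find it by its HuggingFace architecture name. The snippet below is a minimal hypothetical sketch of that pattern, not the contents of `modeling_opt.py`; it assumes `register_auto_model` is importable from `tensorrt_llm._torch.models.modeling_utils` (as used by the in-tree PyTorch models) and that the custom model subclasses `transformers.PreTrainedModel`:

```
# Hypothetical registration sketch; see modeling_opt.py for a complete model.
import transformers

# Assumed import path for the decorator used by the in-tree PyTorch models.
from tensorrt_llm._torch.models.modeling_utils import register_auto_model


@register_auto_model("MyOPTForCausalLM")  # must match `architectures` in the HF config
class MyOPTForCausalLM(transformers.PreTrainedModel):

    def __init__(self, model_config):
        # model_config wraps the HF config; build the decoder from TRT-LLM modules here.
        super().__init__(model_config.pretrained_config)
```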


## Out-of-tree Multimodal Models

For multimodal models, TRT-LLM provides `quickstart_multimodal.py` to quickly run a multimodal model that is defined within TRT-LLM, and `trtllm-bench` can be used to benchmark such models.
The following sections describe how to use these tools with out-of-tree models.

### Prerequisites
To use an out-of-tree model with the quickstart example and `trtllm-bench`, you need to prepare the model definition files as a Python module.
Consider the following file structure as an example:
```
modeling_custom_phi
|-- __init__.py
|-- configuration.py
|-- modeling_custom_phi.py
|-- encoder
|   |-- __init__.py
|   |-- configuration.py
|   |-- modeling_encoder.py
```
Each `__init__.py` file should be populated with the right imports for the custom model. For example, `modeling_custom_phi/__init__.py` can contain something like:
```
from .modeling_custom_phi import MyVLMForConditionalGeneration
from . import encoder
```
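
For a multimodal model, the model file imported by `__init__.py` is also where the input processor and its placeholder metadata are registered. The skeleton below is a hypothetical `modeling_custom_phi.py`; the decorators mirror the registrations this PR adds to the in-tree models, while the class names, `model_type`, the `<image>` placeholder, and the `register_auto_model` import path are illustrative assumptions to replace with your model's own values:

```
# Hypothetical skeleton of modeling_custom_phi.py; adapt names and placeholders
# to the actual custom model.
import transformers

# Assumed import path (matches the in-tree PyTorch models).
from tensorrt_llm._torch.models.modeling_utils import register_auto_model
from tensorrt_llm.inputs import (InputProcessor, MultimodalPlaceholderMetadata,
                                 MultimodalPlaceholderPlacement,
                                 register_input_processor)


class MyVLMInputProcessor(InputProcessor):
    # Turns prompts + media into token ids and multimodal tensors for the model.
    ...


@register_auto_model("MyVLMForConditionalGeneration")  # HF `architectures` entry
@register_input_processor(
    MyVLMInputProcessor,
    model_type="custom_phi",  # `model_type` from the HF config
    placeholder_metadata=MultimodalPlaceholderMetadata(
        placeholder_map={"image": "<image>"},  # this model's image placeholder token
        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
    ))
class MyVLMForConditionalGeneration(transformers.PreTrainedModel):

    def __init__(self, model_config):
        # Build the vision encoder (e.g. from the `encoder` submodule) and the LLM here.
        super().__init__(model_config.pretrained_config)
```

With such a module in place, passing `--custom_module_dirs` makes TRT-LLM import it via `import_custom_module_from_dir`, so these registrations run before inference or benchmarking starts.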

### Quickstart Example

Once the model definition files are prepared as a Python module (as described above), you can pass the `--custom_module_dirs` flag to `quickstart_multimodal.py` to load your model and run inference.

```
python3 quickstart_multimodal.py --model_dir ./model_ckpt --modality image --max_tokens 10 --prompt "Describe the image." --media ./demo_lower.png --image_format pil --custom_module_dirs ../modeling_custom_phi
```

### Benchmarking

As with the quickstart example, you can pass the same CLI argument to `trtllm-bench` to benchmark a custom model.

Prepare the dataset:
```
python ./benchmarks/cpp/prepare_dataset.py --tokenizer ./model_ckpt --stdout dataset --dataset-name lmms-lab/MMMU --dataset-split test --dataset-image-key image --dataset-prompt-key "question" --num-requests 100 --output-len-dist 128,5 > mm_data.jsonl
```


Run the benchmark:
```
trtllm-bench --model ./model_ckpt --model_path ./model_ckpt throughput --dataset mm_data.jsonl --backend pytorch --num_requests 100 --max_batch_size 4 --modality image --streaming --custom_module_dirs ../modeling_custom_phi
```
46 changes: 36 additions & 10 deletions examples/llm-api/quickstart_multimodal.py
@@ -4,8 +4,9 @@

from quickstart_advanced import add_llm_args, setup_llm

from tensorrt_llm.inputs import (ALL_SUPPORTED_MULTIMODAL_MODELS,
default_multimodal_input_loader)
from tensorrt_llm.inputs import default_multimodal_input_loader
from tensorrt_llm.inputs.registry import MULTIMODAL_PLACEHOLDER_REGISTRY
from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir

example_medias_and_prompts = {
"image": {
@@ -79,18 +80,19 @@


def add_multimodal_args(parser):
parser.add_argument("--model_type",
type=str,
choices=ALL_SUPPORTED_MULTIMODAL_MODELS,
help="Model type.")
parser.add_argument(
"--model_type",
type=str,
choices=MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types(),
help="Model type as specified in the HuggingFace model config.")
parser.add_argument("--modality",
type=str,
choices=[
"image", "video", "audio", "image_audio",
"multiple_image", "mixture_text_image"
],
default="image",
help="Media type.")
help="Media type being used for inference.")
parser.add_argument("--media",
type=str,
nargs="+",
@@ -108,6 +110,18 @@ def add_multimodal_args(parser):
type=str,
default="cpu",
help="The device to have the input on.")
parser.add_argument(
"--custom_module_dirs",
type=str,
nargs="+",
default=None,
help=
("Paths to an out-of-tree model directory which should be imported."
" This is useful to load a custom model. The directory should have a structure like:"
" <model_name>"
" ├── __init__.py"
" ├── <model_name>.py"
" └── <sub_dirs>"))
return parser


@@ -140,6 +154,15 @@ def parse_arguments():

def main():
args = parse_arguments()
if args.custom_module_dirs is not None:
for custom_module_dir in args.custom_module_dirs:
try:
import_custom_module_from_dir(custom_module_dir)
except Exception as e:
print(
f"Failed to import custom module from {custom_module_dir}: {e}"
)
raise e

lora_config = None
if args.load_lora:
@@ -159,16 +182,19 @@ def main():
model_type = args.model_type
else:
model_type = json.load(
open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type']
assert model_type in ALL_SUPPORTED_MULTIMODAL_MODELS, f"Unsupported model_type: {model_type}"
open(os.path.join(str(llm._hf_model_dir),
'config.json')))['model_type']
assert model_type in MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types(), \
f"Unsupported model_type: {model_type} found!\n" \
f"Supported types: {MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types()}"

# set prompts and media to example prompts and images if they are not provided
if args.prompt is None:
args.prompt = example_medias_and_prompts[args.modality]["prompt"]
if args.media is None:
args.media = example_medias_and_prompts[args.modality]["media"]
inputs = default_multimodal_input_loader(tokenizer=llm.tokenizer,
model_dir=llm._hf_model_dir,
model_dir=str(llm._hf_model_dir),
model_type=model_type,
modality=args.modality,
prompts=args.prompt,
12 changes: 10 additions & 2 deletions tensorrt_llm/_torch/models/modeling_gemma3vl.py
@@ -10,7 +10,9 @@
BaseWeightMapper

from ..._utils import nvtx_range
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from ...logger import logger
from ...sampling_params import SamplingParams
@@ -137,7 +139,13 @@ def forward(self, vision_outputs: torch.Tensor):


@register_auto_model("Gemma3ForConditionalGeneration")
@register_input_processor(Gemma3InputProcessor, model_type="gemma3")
@register_input_processor(
Gemma3InputProcessor,
model_type="gemma3",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={"image": "<start_of_image>"},
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
))
class Gemma3VLM(PreTrainedModel):

def __init__(self, model_config: ModelConfig[Gemma3Config]):
22 changes: 20 additions & 2 deletions tensorrt_llm/_torch/models/modeling_hyperclovax.py
@@ -15,7 +15,9 @@

from tensorrt_llm.inputs.multimodal import MultimodalParams

from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from ...logger import logger
from ...sampling_params import SamplingParams
@@ -961,7 +963,23 @@ def forward(self, multimodal_params: List[MultimodalParams]):


@register_auto_model("HCXVisionForCausalLM")
@register_input_processor(HCXVisionInputProcessor, model_type="hyperclovax_vlm")
@register_input_processor(
HCXVisionInputProcessor,
model_type="hyperclovax_vlm",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={
"image":
('<im_end>\n<|im_start|>user (mime) \n'
'{"type": "image/jpeg", "filename": ""}<|im_end|>\n'
'<|im_start|>user (vector)\n<|dummy3|><|im_end|>\n'
'<|im_start|>image/aux\n'
'다음 중 ocr은 사진에서 검출된 글자이고, lens_keyword는 사진에서 추출된 '
'keyword와 bbox 위치입니다.bbox는 0~1 사이로 정규화된 [x1, y1, x2, y2]의 '
'형태입니다. 참고하여 답변하세요. '
'{"ocr": "", "lens_keywords": "", "lens_local_keywords": ""}')
},
placeholder_placement=MultimodalPlaceholderPlacement.AFTER_TEXT,
))
class HCXVisionForCausalLM(PreTrainedModel):

def __init__(self, model_config: ModelConfig):
12 changes: 10 additions & 2 deletions tensorrt_llm/_torch/models/modeling_llama.py
@@ -20,7 +20,9 @@
from tensorrt_llm.lora_manager import HfLoraLoader
from tensorrt_llm.models.convert_utils import split_matrix_tp

from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from ...sampling_params import SamplingParams
from ..attention_backend import AttentionMetadata
@@ -1173,7 +1175,13 @@ def __call__(


@register_auto_model("Llama4ForConditionalGeneration")
@register_input_processor(Llama4InputProcessor, model_type="llama4")
@register_input_processor(
Llama4InputProcessor,
model_type="llama4",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={"image": "<|image|>"},
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
))
class Llama4ForConditionalGeneration(SpecDecOneEngineForCausalLM[Llama4Model,
Llama4Config]):

12 changes: 10 additions & 2 deletions tensorrt_llm/_torch/models/modeling_llava_next.py
@@ -14,7 +14,9 @@

from tensorrt_llm.inputs.multimodal import MultimodalParams

from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from ...llmapi.utils import download_hf_model
from ...logger import logger
@@ -263,7 +265,13 @@ def forward(self, multimodal_params: List[MultimodalParams]):


@register_auto_model("LlavaNextForConditionalGeneration")
@register_input_processor(LlavaNextInputProcessor, model_type="llava_next")
@register_input_processor(
LlavaNextInputProcessor,
model_type="llava_next",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={"image": "<image>"},
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
))
class LlavaNextModel(PreTrainedModel):
config_class = LlavaNextConfig

20 changes: 17 additions & 3 deletions tensorrt_llm/_torch/models/modeling_mistral.py
@@ -29,7 +29,9 @@
from tensorrt_llm._torch.speculative import SpecMetadata
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.inputs import (ExtraProcessedInputs, InputProcessor,
TextPrompt, register_input_processor)
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.logger import logger

@@ -269,8 +271,20 @@ def __call__(


@register_auto_model("Mistral3ForConditionalGeneration")
# The below informs the registry which input registry to create for this in `tensorrt_llm/llmapi/llm.py`.
@register_input_processor(Mistral3InputProcessor, model_type="mistral3")
@register_input_processor(
Mistral3InputProcessor,
model_type="mistral3",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={
"image": "[IMG]",
},
# NOTE: for mistral3 multimodal models, it does not strictly have to be before the text.
# Ref: https://github.com/mistralai/mistral-common/blob/039465db2bdc0486df36365c9bdb428188482a18/
# src/mistral_common/tokens/tokenizers/base.py#L326
# However, accuracy tests show that the model generates higher quality output when the image
# precedes the text (the relative difference can be as much as ~30% for both vLLM and TRT-LLM).
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
))
class Mistral3VLM(PreTrainedModel):
"""Mistral3VLM implementation for TRTLLM.

16 changes: 14 additions & 2 deletions tensorrt_llm/_torch/models/modeling_phi4mm.py
@@ -19,7 +19,9 @@
from tensorrt_llm.inputs.multimodal import MultimodalParams

from ...executor.request import LoRARequest
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from ...logger import logger
from ...lora_helper import LoraConfig
@@ -461,7 +463,17 @@ def __call__(


@register_auto_model("Phi4MMForCausalLM")
@register_input_processor(Phi4MMInputProcessor, model_type="phi4mm")
@register_input_processor(
Phi4MMInputProcessor,
model_type="phi4mm",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={
"image": "<|image_{0}|>",
"audio": "<|audio_{0}|>",
},
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
placeholders_separator="",
))
class Phi4MMForCausalLM(transformers.PreTrainedModel):

_supports_flash_attn_2 = True
24 changes: 21 additions & 3 deletions tensorrt_llm/_torch/models/modeling_qwen2vl.py
@@ -12,7 +12,9 @@

from ..._utils import nvtx_range_debug
from ...functional import RopeEmbeddingUtils, RotaryScalingType
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from ...logger import logger
from ...sampling_params import SamplingParams
@@ -645,7 +647,16 @@ def forward(


@register_auto_model("Qwen2VLForConditionalGeneration")
@register_input_processor(Qwen2VLInputProcessorBase, model_type="qwen2_vl")
@register_input_processor(
Qwen2VLInputProcessorBase,
model_type="qwen2_vl",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={
"image": "<|vision_start|><|image_pad|><|vision_end|>",
"video": "<|vision_start|><|video_pad|><|vision_end|>"
},
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
))
class Qwen2VLModel(Qwen2VLModelBase):

def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
@@ -657,7 +668,14 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,


@register_auto_model("Qwen2_5_VLForConditionalGeneration")
@register_input_processor(Qwen2VLInputProcessorBase, model_type="qwen2_5_vl")
@register_input_processor(
Qwen2VLInputProcessorBase,
model_type="qwen2_5_vl",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={
"image": "<|vision_start|><|image_pad|><|vision_end|>",
"video": "<|vision_start|><|video_pad|><|vision_end|>"
}))
class Qwen2_5_VLModel(Qwen2VLModelBase):

def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
15 changes: 13 additions & 2 deletions tensorrt_llm/_torch/models/modeling_vila.py
@@ -35,7 +35,9 @@
PreTrainedModel)

from ..._utils import nvtx_range
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from ...logger import logger
from ...sampling_params import SamplingParams
@@ -1118,7 +1120,16 @@ def __call__(


@register_auto_model(VilaConfig.model_architecture)
@register_input_processor(VilaInputProcessor, model_type="llava_llama")
@register_input_processor(
VilaInputProcessor,
model_type="llava_llama",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={
"image": "<image>",
"video": "<vila/video>"
},
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
))
class VilaModel(PreTrainedModel):
config_class = VilaConfig
