diff --git a/examples/multimodal/README.md b/examples/multimodal/README.md
index 4d1a6d3407..5dad64bdde 100644
--- a/examples/multimodal/README.md
+++ b/examples/multimodal/README.md
@@ -18,7 +18,6 @@ limitations under the License.
 
 # Multimodal Deployment Examples
 
 This directory provides example workflows and reference implementations for deploying a multimodal model using Dynamo.
-The examples are based on the [llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf) model.
 
 ## Use the Latest Release
@@ -59,11 +58,15 @@ flowchart LR
   decode_worker --image_url--> encode_worker
   encode_worker --embeddings--> decode_worker
 ```
-```
 
 ```bash
 cd $DYNAMO_HOME/examples/multimodal
-dynamo serve graphs.agg:Frontend -f ./configs/agg.yaml
+# Serve a LLaVA 1.5 7B model:
+dynamo serve graphs.agg:Frontend -f ./configs/agg-llava.yaml
+# Serve a Qwen2.5-VL model:
+# dynamo serve graphs.agg:Frontend -f ./configs/agg-qwen.yaml
+# Serve a Phi3V model:
+# dynamo serve graphs.agg:Frontend -f ./configs/agg-phi3v.yaml
 ```
 
 ### Client
@@ -92,10 +95,13 @@ curl http://localhost:8000/v1/chat/completions \
       }
     ],
     "max_tokens": 300,
+    "temperature": 0.0,
     "stream": false
   }'
 ```
 
+If serving the example Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`. If serving the example Phi3V model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"microsoft/Phi-3.5-vision-instruct"`.
+
 You should see a response similar to this:
 ```json
 {"id": "c37b946e-9e58-4d54-88c8-2dbd92c47b0c", "object": "chat.completion", "created": 1747725277, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " In the image, there is a city bus parked on a street, with a street sign nearby on the right side. The bus appears to be stopped out of service. The setting is in a foggy city, giving it a slightly moody atmosphere."}, "finish_reason": "stop"}]}
@@ -162,6 +168,7 @@ curl http://localhost:8000/v1/chat/completions \
       }
     ],
     "max_tokens": 300,
+    "temperature": 0.0,
     "stream": false
   }'
 ```
@@ -171,6 +178,8 @@ You should see a response similar to this:
 {"id": "c1774d61-3299-4aa3-bea1-a0af6c055ba8", "object": "chat.completion", "created": 1747725645, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " This image shows a passenger bus traveling down the road near power lines and trees. The bus displays a sign that says \"OUT OF SERVICE\" on its front."}, "finish_reason": "stop"}]}
 ```
 
+***Note***: Disaggregated serving is currently confirmed to work only with LLaVA; support for Qwen2.5-VL and Phi3V has not been confirmed.
+
 ## Deployment with Dynamo Operator
 
 These multimodal examples can be deployed to a Kubernetes cluster using [Dynamo Cloud](../../docs/guides/dynamo_deploy/dynamo_cloud.md) and the Dynamo CLI.
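For reference, the curl requests above can also be issued from Python. The sketch below is illustrative only: it assumes the frontend's OpenAI-compatible `/v1/chat/completions` endpoint on `localhost:8000`, uses a placeholder image URL, and swaps in the Qwen model name per the substitution note above.

```python
# Illustrative client sketch; mirrors the curl examples above.
# The image URL below is a placeholder -- substitute your own.
import requests

payload = {
    "model": "Qwen/Qwen2.5-VL-7B-Instruct",  # or "llava-hf/llava-1.5-7b-hf" / "microsoft/Phi-3.5-vision-instruct"
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/bus.jpg"}},
            ],
        }
    ],
    "max_tokens": 300,
    "temperature": 0.0,
    "stream": False,
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```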
@@ -206,8 +215,12 @@ DYNAMO_TAG=$(dynamo build graphs.agg:Frontend | grep "Successfully built" | awk
 
 # Deploy to Kubernetes
 export DEPLOYMENT_NAME=multimodal-agg
-# For aggregated serving:
-dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg.yaml
+# For aggregated serving with LLaVA:
+dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg-llava.yaml
+# For aggregated serving with Qwen2.5-VL:
+# dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg-qwen.yaml
+# For aggregated serving with Phi3V:
+# dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg-phi3v.yaml
 # For disaggregated serving:
 # export DEPLOYMENT_NAME=multimodal-disagg
 # dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/disagg.yaml
@@ -244,8 +257,11 @@ curl localhost:8000/v1/chat/completions \
       }
     ],
     "max_tokens": 300,
+    "temperature": 0.0,
     "stream": false
   }'
 ```
 
+If serving the example Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`. If serving the example Phi3V model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"microsoft/Phi-3.5-vision-instruct"`.
+
 For more details on managing deployments, testing, and troubleshooting, please refer to the [Operator Deployment Guide](../../docs/guides/dynamo_deploy/operator_deployment.md).
diff --git a/examples/multimodal/components/decode_worker.py b/examples/multimodal/components/decode_worker.py
index 59ac84a162..eb627ae9d5 100644
--- a/examples/multimodal/components/decode_worker.py
+++ b/examples/multimodal/components/decode_worker.py
@@ -24,8 +24,8 @@
 from components.disagg_router import PyDisaggregatedRouter
 from components.encode_worker import VllmEncodeWorker
 from components.prefill_worker import VllmPrefillWorker
-from transformers import LlavaForConditionalGeneration
 from utils.logging import check_required_workers
+from utils.model import construct_mm_data, get_vision_embeddings_info
 from utils.nixl import NixlMetadataStore
 from utils.prefill_queue import PrefillQueue
 from utils.protocol import (
@@ -117,6 +117,11 @@ async def async_init(self):
         )
 
         runtime = dynamo_context["runtime"]
+        embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
+            self.engine_args.model, self.engine_args.num_patches
+        )
+        logger.debug(f"Embeddings shape: {embeddings_shape}")
+        self.embedding_size = embeddings_shape[1]
 
         if self.do_remote_prefill:
             metadata = self.engine_client.nixl_metadata
@@ -133,18 +138,7 @@ async def async_init(self):
                 await self.disaggregated_router.async_init()
             else:
                 self.disaggregated_router = None
-
-            model = LlavaForConditionalGeneration.from_pretrained(
-                self.engine_args.model,
-                device_map="auto",
-                torch_dtype=torch.bfloat16,
-            ).eval()
-            vision_tower = model.vision_tower
-            self.embedding_size = (
-                vision_tower.vision_model.embeddings.position_embedding.num_embeddings
-            )
         else:
-            EMBEDDINGS_SHAPE = (1, 577, 4096)
             EMBEDDINGS_DTYPE = torch.float16
             EMBEDDINGS_DEVICE = "cuda"
 
@@ -161,7 +155,7 @@ async def async_init(self):
 
             # Create a longer-lived buffer for receiving the image embeddings.
             embeddings = torch.empty(
-                EMBEDDINGS_SHAPE, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
+                embeddings_shape, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
             )
             descriptor = connect.Descriptor(embeddings)
             # Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
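As a rough illustration of how the new `num-patches` setting drives the receive buffer above, the following sketch assumes the `agg-llava.yaml` values (`num-patches: 576`) and LLaVA 1.5 7B's config (4096 hidden size, float16 `torch_dtype`); other models yield other shapes and dtypes.

```python
# Sketch only: reproduces the buffer sizing shown in the hunk above.
# Assumes the agg-llava.yaml config (num-patches: 576) and LLaVA 1.5 7B's
# hidden_size=4096 / torch_dtype=float16.
import torch

from utils.model import get_vision_embeddings_info

embeddings_shape, embeddings_dtype = get_vision_embeddings_info(
    "llava-hf/llava-1.5-7b-hf", num_patches=576
)
# Expected: embeddings_shape == (1, 576, 4096), embeddings_dtype == torch.float16
embeddings = torch.empty(embeddings_shape, dtype=torch.float16, device="cuda")
print(embeddings_shape, embeddings_dtype)
```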
@@ -206,13 +200,15 @@ async def generate(self, request: vLLMMultimodalRequest):
                 multi_modal_data,
                 remote_prefill_params,
             ) = await self.remote_prefill(request)
-
         else:
             (
                 prompt_ids,
                 multi_modal_data,
                 remote_prefill_params,
             ) = await self.local_prefill(request)
+        logger.debug(f"Prompt ids: {prompt_ids}")
+        logger.debug(f"Multi modal data: {multi_modal_data}")
+        logger.debug(f"Remote prefill params: {remote_prefill_params}")
 
         # rust HTTP requires Delta streaming
         request.sampling_params.output_kind = RequestOutputKind.DELTA
@@ -227,7 +223,7 @@ async def generate(self, request: vLLMMultimodalRequest):
             remote_prefill_params=remote_prefill_params,
         ):
             logger.debug(
-                f"Yeilding response {{ id: {response.request_id}, prompt: '{response.prompt}' }}"
+                f"Yielding response {{ id: {response.request_id}, prompt: '{response.prompt}' }}"
             )
             yield MyRequestOutput(
                 request_id=response.request_id,
@@ -294,7 +290,9 @@ async def local_prefill(self, request: vLLMMultimodalRequest) -> tuple:
                 "Aggregated: embedding data from encode worker provided via multi-modal data to decode model."
             )
             # When using disaggregated serving, the encode worker will have provided the key-value cache updates via the encode worker.
-            multi_modal_data = {"image": embeddings}
+            multi_modal_data = construct_mm_data(
+                self.engine_args.model, encode_output, embeddings, self.embeddings_dtype
+            )
 
         return prompt_ids, multi_modal_data, remote_prefill_params
 
@@ -353,17 +351,16 @@ async def remote_prefill(self, request: vLLMMultimodalRequest) -> tuple:
         # As a workaround, here we manually insert some placeholder dummy tokens based on the embedding size
         # so that decode worker can pre-allocate the memory with the correct size.
         # The structure of the prompt will be like: "<s>\nUSER: <image>\n<prompt>\nASSISTANT:".
-        # Since the "<image>" token is included in the prompt, only need to insert (embedding_size - 1) dummy tokens after the image token.
-        IMAGE_TOKEN_ID = 32000
+        # Since the "<image>" token is included in the prompt, only need to insert embedding_size dummy tokens after the image token.
         DUMMY_TOKEN_ID = 0
         # Find the index of the image token in the prompt token ids
         image_token_index = request.engine_prompt["prompt_token_ids"].index(
-            IMAGE_TOKEN_ID
+            self.engine_args.image_token_id
         )
         dummy_token_index = image_token_index + 1
         prompt_ids = (
             request.engine_prompt["prompt_token_ids"][:dummy_token_index]
-            + [DUMMY_TOKEN_ID] * (self.embedding_size - 1)
+            + [DUMMY_TOKEN_ID] * self.embedding_size
             + request.engine_prompt["prompt_token_ids"][dummy_token_index:]
         )
         logger.debug(
diff --git a/examples/multimodal/components/encode_worker.py b/examples/multimodal/components/encode_worker.py
index a3dd1037d2..a28c95e5b6 100644
--- a/examples/multimodal/components/encode_worker.py
+++ b/examples/multimodal/components/encode_worker.py
@@ -26,7 +26,8 @@
 import httpx
 import torch
 from PIL import Image
-from transformers import AutoImageProcessor, LlavaForConditionalGeneration
+from transformers import AutoImageProcessor
+from utils.model import load_vision_model
 from utils.protocol import EncodeRequest, EncodeResponse
 from utils.vllm import parse_vllm_args
 
@@ -66,10 +67,7 @@ def __init__(self) -> None:
         self.image_processor = AutoImageProcessor.from_pretrained(
             self.MODEL_ID, trust_remote_code=True
         )
-
-        self.vision_model = LlavaForConditionalGeneration.from_pretrained(
-            self.MODEL_ID, device_map="auto", torch_dtype=torch.float16
-        ).eval()
+        self.vision_model = load_vision_model(self.MODEL_ID)
 
         self._image_cache: dict[str, Image.Image] = {}
         self._cache_queue: Queue[str] = Queue(maxsize=CACHE_SIZE_MAXIMUM)
@@ -167,17 +165,32 @@ async def encode(self, request: EncodeRequest) -> AsyncIterator[EncodeResponse]:
            logger.debug(f"Processing image for request: {{ id: {request_id} }}")
 
            image_embeds = self.image_processor(images=image, return_tensors="pt")
+           # Add a batch dimension to everything
+           for item in image_embeds:
+               image_embeds[item] = image_embeds[item].unsqueeze(0).to(DEVICE)
+           logger.debug(f"Image embeds: {image_embeds}")
+
+           image_grid_thw = (
+               image_embeds["image_grid_thw"].tolist()
+               if "image_grid_thw" in image_embeds
+               else None
+           )
+           image_sizes = (
+               image_embeds["image_sizes"].tolist()
+               if "image_sizes" in image_embeds
+               else [image.size]
+           )
+           logger.debug(
+               f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
+           )
 
            with torch.no_grad():
-               logger.debug(f"Vision model device: {self.vision_model.device}")
-               vision_outputs = self.vision_model.vision_tower(
-                   image_embeds["pixel_values"].to(self.vision_model.device)
-               )
-               logger.debug("Vision model completed.")
-
-               embeddings = vision_outputs.last_hidden_state
-               embeddings = self.vision_model.multi_modal_projector(embeddings)
-
+               embeddings = self.vision_model.get_multimodal_embeddings(**image_embeds)
+               if isinstance(embeddings, tuple) or isinstance(embeddings, list):
+                   # The result multimodal_embeddings may be a list or tuple of tensors, with each
+                   # tensor corresponding to a multimodal data item (image or video).
+                   # TODO: for multi-image support, this result will contain multiple tensors.
+                   embeddings = embeddings[0].unsqueeze(0)
 
            logger.debug(
                f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}."
            )
 
@@ -201,6 +214,8 @@ async def encode(self, request: EncodeRequest) -> AsyncIterator[EncodeResponse]:
 
            yield EncodeResponse(
                request_id=request.request_id,
+               image_grid_thw=image_grid_thw,
+               image_sizes=image_sizes,
            ).model_dump_json()
        except Exception as e:
            logger.error(f"Error processing request {request_id}: {e}")
diff --git a/examples/multimodal/components/prefill_worker.py b/examples/multimodal/components/prefill_worker.py
index f1f34c9499..7347a6c28e 100644
--- a/examples/multimodal/components/prefill_worker.py
+++ b/examples/multimodal/components/prefill_worker.py
@@ -25,6 +25,7 @@
 from components.encode_worker import VllmEncodeWorker
 from pydantic import BaseModel
 from utils.logging import check_required_workers
+from utils.model import construct_mm_data, get_vision_embeddings_info
 from utils.nixl import NixlMetadataStore
 from utils.prefill_queue import PrefillQueue
 from utils.protocol import EncodeRequest, EncodeResponse
@@ -39,9 +40,6 @@
 
 logger = logging.getLogger(__name__)
 
-# Constants for the shape and dtype of the embeddings tensor.
-EMBEDDINGS_SHAPE = (1, 577, 4096)
-EMBEDDINGS_DTYPE = torch.float16
 EMBEDDINGS_DEVICE = "cuda"
 
 
@@ -113,9 +111,12 @@ async def async_init(self):
         await self._connector.initialize()
 
         # Create a longer-lived buffer for receiving the image embeddings.
+        embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
+            self.engine_args.model, self.engine_args.num_patches
+        )
         embeddings = torch.empty(
-            EMBEDDINGS_SHAPE,
-            dtype=EMBEDDINGS_DTYPE,
+            embeddings_shape,
+            dtype=self.embeddings_dtype,
             device=EMBEDDINGS_DEVICE,
         )
         descriptor = connect.Descriptor(embeddings)
@@ -248,10 +249,11 @@ async def generate(self, request: RemotePrefillRequest):
         # To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
         # some placeholder dummy tokens are inserted based on the embedding size in the worker.py.
         # TODO: make this more flexible/model-dependent
-        IMAGE_TOKEN_ID = 32000
         embedding_size = embeddings.shape[1]
-        padding_size = embedding_size - 1
+        padding_size = embedding_size
-        image_token_index = request.prompt_token_ids.index(IMAGE_TOKEN_ID)
+        image_token_index = request.prompt_token_ids.index(
+            self.engine_args.image_token_id
+        )
         dummy_token_index = image_token_index + 1
         prompt_token_ids = (
             request.prompt_token_ids[:dummy_token_index]
@@ -262,7 +264,12 @@ async def generate(self, request: RemotePrefillRequest):
             request_id=request_id,
             prompt=TokensPrompt(
                 prompt_token_ids=prompt_token_ids,
-                multi_modal_data={"image": embeddings},
+                multi_modal_data=construct_mm_data(
+                    self.engine_args.model,
+                    encode_output,
+                    embeddings,
+                    self.embeddings_dtype,
+                ),
             ),
             sampling_params=sampling_params,
             remote_prefill_params=remote_prefill_params,
diff --git a/examples/multimodal/components/processor.py b/examples/multimodal/components/processor.py
index b1628a63a4..d45e5ab42b 100644
--- a/examples/multimodal/components/processor.py
+++ b/examples/multimodal/components/processor.py
@@ -188,9 +188,19 @@ async def _generate_responses(
     # The generate endpoint will be used by the frontend to handle incoming requests.
     @endpoint()
     async def generate(self, raw_request: MultiModalRequest):
-        prompt = str(self.engine_args.prompt_template).replace(
-            "<prompt>", raw_request.messages[0].content[0].text
-        )
+        # Ensure the configured template includes the <prompt> placeholder
+        template = self.engine_args.prompt_template
+        if "<prompt>" not in template:
+            raise ValueError("prompt_template must contain '<prompt>' placeholder")
+
+        # Safely extract user text
+        try:
+            user_text = raw_request.messages[0].content[0].text
+        except (IndexError, AttributeError) as e:
+            raise ValueError(f"Invalid message structure: {e}")
+
+        prompt = template.replace("<prompt>", user_text)
+
         msg = {
             "role": "user",
             "content": prompt,
@@ -201,6 +211,7 @@ async def generate(self, raw_request: MultiModalRequest):
             messages=[msg],
             stream=raw_request.stream,
             max_tokens=raw_request.max_tokens,
+            temperature=raw_request.temperature,
             request_id=str(uuid.uuid4()),
         )
         image_url = None
diff --git a/examples/multimodal/configs/agg.yaml b/examples/multimodal/configs/agg-llava.yaml
similarity index 96%
rename from examples/multimodal/configs/agg.yaml
rename to examples/multimodal/configs/agg-llava.yaml
index 344a6e46c1..c0c7346525 100644
--- a/examples/multimodal/configs/agg.yaml
+++ b/examples/multimodal/configs/agg-llava.yaml
@@ -26,6 +26,8 @@ VllmDecodeWorker:
   enforce-eager: true
   max-num-batched-tokens: 16384
   enable-prefix-caching: true
+  image-token-id: 32000
+  num-patches: 576
   router: random
   tensor-parallel-size: 1
   ServiceArgs:
diff --git a/examples/multimodal/configs/agg-phi3v.yaml b/examples/multimodal/configs/agg-phi3v.yaml
new file mode 100644
index 0000000000..bc794ae546
--- /dev/null
+++ b/examples/multimodal/configs/agg-phi3v.yaml
@@ -0,0 +1,50 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+Common:
+  model: microsoft/Phi-3.5-vision-instruct
+  block-size: 64
+  max-model-len: 4096
+  trust-remote-code: true
+
+Processor:
+  router: round-robin
+  prompt-template: "<|user|>\n<|image_1|>\n<prompt><|end|>\n<|assistant|>\n"
+  common-configs: [model, block-size, max-model-len, trust-remote-code]
+
+VllmDecodeWorker:
+  enforce-eager: true
+  max-num-batched-tokens: 16384
+  max-num-seqs: 2
+  mm-processor-kwargs:
+    num_crops: 16
+  enable-prefix-caching: true
+  image-token-id: 32000
+  num-patches: 757
+  router: random
+  tensor-parallel-size: 1
+  ServiceArgs:
+    workers: 1
+    resources:
+      gpu: '1'
+  common-configs: [model, block-size, max-model-len, trust-remote-code]
+
+VllmEncodeWorker:
+  tensor-parallel-size: 1
+  router: random
+  ServiceArgs:
+    workers: 1
+    resources:
+      gpu: '1'
+  common-configs: [model]
diff --git a/examples/multimodal/configs/agg-qwen.yaml b/examples/multimodal/configs/agg-qwen.yaml
new file mode 100644
index 0000000000..324a4ffc57
--- /dev/null
+++ b/examples/multimodal/configs/agg-qwen.yaml
@@ -0,0 +1,51 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+Common:
+  model: Qwen/Qwen2.5-VL-7B-Instruct
+  block-size: 64
+  max-model-len: 4096
+
+Processor:
+  router: round-robin
+  prompt-template: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><prompt><|im_end|>\n<|im_start|>assistant\n"
+  common-configs: [model, block-size, max-model-len]
+
+VllmDecodeWorker:
+  enforce-eager: true
+  max-num-batched-tokens: 16384
+  max-num-seqs: 5
+  mm-processor-kwargs:
+    min_pixels: 784
+    max_pixels: 1003520
+    fps: 1
+  enable-prefix-caching: true
+  image-token-id: 151655
+  num-patches: 345
+  router: random
+  tensor-parallel-size: 1
+  ServiceArgs:
+    workers: 1
+    resources:
+      gpu: '1'
+  common-configs: [model, block-size, max-model-len]
+
+VllmEncodeWorker:
+  tensor-parallel-size: 1
+  router: random
+  ServiceArgs:
+    workers: 1
+    resources:
+      gpu: '1'
+  common-configs: [model]
diff --git a/examples/multimodal/configs/disagg.yaml b/examples/multimodal/configs/disagg.yaml
index 6c6fbbb200..bbfd13ac02 100644
--- a/examples/multimodal/configs/disagg.yaml
+++ b/examples/multimodal/configs/disagg.yaml
@@ -16,6 +16,8 @@ Common:
   model: llava-hf/llava-1.5-7b-hf
   block-size: 64
   max-model-len: 4096
+  image-token-id: 32000
+  num-patches: 576
   kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
 
 Processor:
@@ -32,7 +34,7 @@ VllmDecodeWorker:
     workers: 1
     resources:
       gpu: '1'
-  common-configs: [model, block-size, max-model-len, kv-transfer-config]
+  common-configs: [model, block-size, image-token-id, max-model-len, num-patches, kv-transfer-config]
 
 VllmPrefillWorker:
   max-num-batched-tokens: 16384
@@ -40,7 +42,7 @@ VllmPrefillWorker:
     workers: 1
     resources:
       gpu: '1'
-  common-configs: [model, block-size, max-model-len, kv-transfer-config]
+  common-configs: [model, block-size, image-token-id, max-model-len, num-patches, kv-transfer-config]
 
 VllmEncodeWorker:
   tensor-parallel-size: 1
diff --git a/examples/multimodal/utils/model.py b/examples/multimodal/utils/model.py
new file mode 100644
index 0000000000..c72e0a518f
--- /dev/null
+++ b/examples/multimodal/utils/model.py
@@ -0,0 +1,93 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Any, Dict, Tuple
+
+import torch
+from transformers import AutoConfig
+from utils.protocol import EncodeResponse
+from vllm import AsyncEngineArgs
+from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+from vllm.worker.worker import Worker
+
+logger = logging.getLogger(__name__)
+
+
+def load_vision_model(model_id: str) -> torch.nn.Module:
+    """
+    Load a vision model from a HuggingFace model ID.
+    """
+    engine_args = AsyncEngineArgs(model=model_id, trust_remote_code=True)
+
+    engine_config = engine_args.create_engine_config()
+    distributed_init_method = get_distributed_init_method(get_ip(), get_open_port())
+    worker = Worker(
+        vllm_config=engine_config,
+        local_rank=0,
+        rank=0,
+        distributed_init_method=distributed_init_method,
+        is_driver_worker=True,
+    )
+    # Initialize the worker.
+    worker.init_device()
+    worker.load_model()
+    return worker.model_runner.model
+
+
+def get_vision_embeddings_info(
+    model_id: str, num_patches: int
+) -> Tuple[Tuple[int, int, int], torch.dtype]:
+    """Calculate vision embeddings size and dtype using model config.
+    Returns a tuple of (batch_size, num_patches, hidden_dim), dtype.
+    """
+    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
+    assert num_patches > 0, "Number of patches must be positive"
+    if not hasattr(config, "torch_dtype"):
+        raise ValueError("Model config missing required 'torch_dtype' attribute")
+    if not hasattr(config, "hidden_size"):
+        logger.warning(
+            "Model config missing required 'hidden_size' attribute, using 4096"
+        )
+        hidden_size = 4096
+    else:
+        hidden_size = config.hidden_size
+    return (1, num_patches, hidden_size), config.torch_dtype
+
+
+def construct_mm_data(
+    model: str,
+    encode_output: EncodeResponse,
+    image_embeds: torch.Tensor,
+    embeddings_dtype: torch.dtype,
+) -> Dict[str, torch.Tensor | Dict[str, Any]]:
+    """Construct multimodal data for a vLLM request for models that require additional parameters alongside the embeddings"""
+    image_embeds = image_embeds.to(embeddings_dtype)
+    if "Qwen2" in model:
+        return {
+            "image": {
+                "image_embeds": image_embeds.squeeze(0),
+                "image_grid_thw": torch.tensor(encode_output.image_grid_thw).squeeze(0),
+            }
+        }
+    elif "MiniCPM-V" in model:
+        return {
+            "image": {
+                "image_embeds": image_embeds,
+                "image_sizes": encode_output.image_sizes,
+            }
+        }
+    else:
+        return {"image": image_embeds}
diff --git a/examples/multimodal/utils/protocol.py b/examples/multimodal/utils/protocol.py
index 6613c0680a..d72807e47a 100644
--- a/examples/multimodal/utils/protocol.py
+++ b/examples/multimodal/utils/protocol.py
@@ -119,6 +119,7 @@ class MultiModalRequest(BaseModel):
     model: str
     messages: List[ChatMessage]
     max_tokens: Optional[int] = None
+    temperature: Optional[float] = None
     stream: Optional[bool] = True
 
 
@@ -141,6 +142,8 @@ class EncodeRequest(BaseModel):
 class EncodeResponse(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
     request_id: str
+    image_grid_thw: Optional[List[Any]] = None
+    image_sizes: Optional[List[Any]] = None
 
 
 class MyRequestOutput(BaseModel):
diff --git a/examples/multimodal/utils/vllm.py b/examples/multimodal/utils/vllm.py
index 7b6b1d888c..f98e2ac065 100644
--- a/examples/multimodal/utils/vllm.py
+++ b/examples/multimodal/utils/vllm.py
@@ -51,6 +51,18 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
         default=3,
         help="Maximum queue size for remote prefill. If the prefill queue size is greater than this value, prefill phase of the incoming request will be executed locally.",
     )
+    parser.add_argument(
+        "--image-token-id",
+        type=int,
+        default=32000,
+        help="Image token ID used to represent image patches in the token sequence",
+    )
+    parser.add_argument(
+        "--num-patches",
+        type=int,
+        default=576,
+        help="Number of patches the input image is divided into (must be positive)",
+    )
     parser.add_argument(
         "--prompt-template",
         type=str,
@@ -66,4 +78,6 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
     engine_args.max_local_prefill_length = args.max_local_prefill_length
     engine_args.max_prefill_queue_size = args.max_prefill_queue_size
     engine_args.prompt_template = args.prompt_template
+    engine_args.num_patches = args.num_patches
+    engine_args.image_token_id = args.image_token_id
     return engine_args
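To make the roles of `--image-token-id` and `--num-patches` concrete, here is a small, self-contained sketch of the dummy-token padding that the decode and prefill workers apply around the image token (values shown are the LLaVA defaults; the toy prompt below is hypothetical):

```python
# Standalone sketch of the padding logic in decode_worker.remote_prefill /
# prefill_worker.generate: insert num_patches dummy tokens right after the image
# token so the decode engine reserves KV-cache space for the image embeddings.
IMAGE_TOKEN_ID = 32000  # --image-token-id (LLaVA default; 151655 for Qwen2.5-VL)
NUM_PATCHES = 576       # --num-patches; matches the embedding tensor's second dimension
DUMMY_TOKEN_ID = 0


def pad_prompt_for_image(prompt_token_ids: list[int]) -> list[int]:
    image_token_index = prompt_token_ids.index(IMAGE_TOKEN_ID)
    dummy_token_index = image_token_index + 1
    return (
        prompt_token_ids[:dummy_token_index]
        + [DUMMY_TOKEN_ID] * NUM_PATCHES
        + prompt_token_ids[dummy_token_index:]
    )


# Hypothetical token ids; only the 32000 (image token) position matters here.
padded = pad_prompt_for_image([1, 3148, 1001, 29901, 32000, 13, 22933, 9047, 13566, 29901])
assert len(padded) == 10 + NUM_PATCHES
```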