diff --git a/examples/multimodal/README.md b/examples/multimodal/README.md
index 6b7963cecd..fbadd97001 100644
--- a/examples/multimodal/README.md
+++ b/examples/multimodal/README.md
@@ -59,12 +59,14 @@ flowchart LR
   pd_worker --> encode_worker
 ```
 
-***Note*** Only the LLaVA 1.5 7B model is supported. Qwen2.5-VL and Phi3V support will be added in the future.
+***Note*** Aggregated serving currently supports LLaVA 1.5 7B and Qwen2.5-VL-7B-Instruct; Phi3V support will be added in the future. Disaggregated serving is currently confirmed only for LLaVA (see the note below).
 
 ```bash
 cd $DYNAMO_HOME/examples/multimodal
 # Serve a LLaVA 1.5 7B model:
 bash launch/agg.sh --model llava-hf/llava-1.5-7b-hf
+# Serve a Qwen2.5-VL model:
+bash launch/agg.sh --model Qwen/Qwen2.5-VL-7B-Instruct
 ```
 
 ### Client
@@ -98,6 +100,8 @@ curl http://localhost:8080/v1/chat/completions \
   }'
 ```
 
+If you are serving the Qwen model instead, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`.
+
 You should see a response similar to this:
 ```json
 {"id": "c37b946e-9e58-4d54-88c8-2dbd92c47b0c", "object": "chat.completion", "created": 1747725277, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " In the image, there is a city bus parked on a street, with a street sign nearby on the right side. The bus appears to be stopped out of service. The setting is in a foggy city, giving it a slightly moody atmosphere."}, "finish_reason": "stop"}]}
diff --git a/examples/multimodal/components/encode_worker.py b/examples/multimodal/components/encode_worker.py
index 904434c33f..09c222199a 100644
--- a/examples/multimodal/components/encode_worker.py
+++ b/examples/multimodal/components/encode_worker.py
@@ -21,9 +21,8 @@
 import sys
 from typing import AsyncIterator, Tuple
 
-import torch
 import uvloop
-from transformers import AutoImageProcessor, LlavaForConditionalGeneration
+from transformers import AutoImageProcessor
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.utils import FlexibleArgumentParser
 
@@ -33,7 +32,9 @@
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
 
 from utils.args import Config, base_parse_args, parse_endpoint
+from utils.encode_utils import encode_image_embeddings, get_encoder_components
 from utils.image_loader import ImageLoader
+from utils.model import load_vision_model
 from utils.protocol import MyRequestOutput, vLLMMultimodalRequest
 
 configure_dynamo_logging()
@@ -70,13 +71,14 @@ def __init__(
         self.image_processor = AutoImageProcessor.from_pretrained(
             self.model, trust_remote_code=True
         )
-        # self.vision_model = load_vision_model(self.model)
-        self.vision_model = LlavaForConditionalGeneration.from_pretrained(
-            self.model, device_map="auto", torch_dtype=torch.float16
-        ).eval()
-
+        self.vision_model = load_vision_model(self.model)
         self.min_workers = 1
 
+        # Get the encoder components for the model
+        self.vision_encoder, self.projector = get_encoder_components(
+            self.model, self.vision_model
+        )
+
     def cleanup(self):
         pass
 
@@ -108,49 +110,26 @@ async def generate(
         logger.debug(f"Processing image for request: {{ id: {request_id} }}")
 
         image_embeds = self.image_processor(images=image, return_tensors="pt")
-        # [gluo NOTE] The commented section is for VLM generalization support,
-        # will use more generic approach once utils/model.py is fixed,
-        # see utils/models.py for details.
- # # Add a batch dimension to everything - # for item in image_embeds: - # image_embeds[item] = image_embeds[item].unsqueeze(0).to(DEVICE) - # logger.debug(f"Image embeds: {image_embeds}") - - # image_grid_thw = ( - # image_embeds["image_grid_thw"].tolist() - # if "image_grid_thw" in image_embeds - # else None - # ) - # image_sizes = ( - # image_embeds["image_sizes"].tolist() - # if "image_sizes" in image_embeds - # else [image.size] - # ) - # logger.debug( - # f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}" - # ) - - # with torch.no_grad(): - # embeddings = self.vision_model.get_multimodal_embeddings(**image_embeds) - # if isinstance(embeddings, tuple) or isinstance(embeddings, list): - # # The result multimodal_embeddings may be a list or tuple of tensors, with each - # # tensor corresponding to a multimodal data item (image or video). - # # TODO: for multi-image support, this result will contain multiple tensors. - # embeddings = embeddings[0].unsqueeze(0) - # logger.debug( - # f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}." - # ) - - with torch.no_grad(): - logger.debug(f"Vision model device: {self.vision_model.device}") - vision_outputs = self.vision_model.vision_tower( - image_embeds["pixel_values"].to(self.vision_model.device) - ) - logger.debug("Vision model completed.") - - embeddings = vision_outputs.last_hidden_state - embeddings = self.vision_model.multi_modal_projector(embeddings) + # Encode the image embeddings using model-specific encoder + embeddings = encode_image_embeddings( + model_name=self.model, + image_embeds=image_embeds, + vision_encoder=self.vision_encoder, + projector=self.projector, + ) + + image_grid_thw = ( + image_embeds["image_grid_thw"].tolist() + if "image_grid_thw" in image_embeds + else None + ) + logger.debug( + f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}" + ) + + request.image_grid_thw = image_grid_thw + request.embeddings_shape = tuple(embeddings.shape) descriptor = connect.Descriptor(embeddings) with self._connector.create_readable(descriptor) as readable: diff --git a/examples/multimodal/components/worker.py b/examples/multimodal/components/worker.py index 5b0a9faf95..a549088158 100644 --- a/examples/multimodal/components/worker.py +++ b/examples/multimodal/components/worker.py @@ -24,7 +24,6 @@ import torch import uvloop -from transformers import AutoImageProcessor from vllm.distributed.kv_events import ZmqEventPublisher from vllm.engine.arg_utils import AsyncEngineArgs from vllm.inputs.data import TokensPrompt @@ -47,6 +46,7 @@ parse_endpoint, ) from utils.image_loader import ImageLoader +from utils.model import construct_mm_data from utils.protocol import MyRequestOutput, vLLMMultimodalRequest configure_dynamo_logging() @@ -245,37 +245,15 @@ async def async_init(self, runtime: DistributedRuntime): .client() ) - EMBEDDINGS_DTYPE = torch.float16 - EMBEDDINGS_DEVICE = "cpu" + self.EMBEDDINGS_DTYPE = torch.float16 + self.EMBEDDINGS_DEVICE = "cpu" # Create and initialize a dynamo connector for this worker. 
        # We'll need this to move data between this worker and remote workers efficiently.
         parsed_namespace, _, _ = parse_endpoint(self.endpoint)
         self._connector = connect.Connector()
         await self._connector.initialize()
 
-        # embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
-        #     self.engine_args.model, self.engine_args.num_patches
-        # )
-        # [gluo NOTE] Hardcoded for now, will use more generic approach once utils/model.py
-        # is fixed, see utils/models.py for details.
-        embeddings_shape = (1, 577, 4096)
-        logger.debug(f"Embeddings shape: {embeddings_shape}")
-        self.embedding_size = embeddings_shape[1]
-
-        embeddings = torch.empty(
-            embeddings_shape, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
-        )
-
-        descriptor = connect.Descriptor(embeddings)
-
-        # Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
-        # descriptor.register_memory(self._connector)
-        self._embeddings_descriptor = (embeddings, descriptor)
-
         self.image_loader = ImageLoader()
-        self.image_processor = AutoImageProcessor.from_pretrained(
-            self.engine_args.model, trust_remote_code=True
-        )
 
         logger.info("VllmPDWorker has been initialized")
 
@@ -288,10 +266,18 @@ async def generate(self, request: vLLMMultimodalRequest):
         request = vLLMMultimodalRequest.model_validate(request)
         logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
 
-        if request.image_url is None:
-            # Process embeddings using the connector
-            embeddings, descriptor = self._embeddings_descriptor
+        embeddings, descriptor = None, None
+
+        if request.image_url is None:
+            # Process embeddings using the connector.
+            # Create a descriptor based on the embeddings shape provided by the encode worker.
+            embeddings = torch.empty(
+                request.embeddings_shape,
+                dtype=self.EMBEDDINGS_DTYPE,
+                device=self.EMBEDDINGS_DEVICE,
+            )
+            descriptor = connect.Descriptor(embeddings)
 
             if descriptor is None:
                 raise RuntimeError(
                     "Descriptor is None in PD worker - cannot process embeddings"
@@ -301,15 +287,17 @@ async def generate(self, request: vLLMMultimodalRequest):
                 request.serialized_request, descriptor
             )
             await read_op.wait_for_completion()
-            logger.debug(f"in PD worker, image features: {embeddings}")
 
-            multi_modal_data = embeddings
+            multi_modal_data = construct_mm_data(
+                self.engine_args.model,
+                embeddings,
+                self.EMBEDDINGS_DTYPE,
+                request.image_grid_thw,
+            )
         else:
             # Use PIL image instead of image embeddings
-            multi_modal_data = await self.image_loader.load_image(request.image_url)
-            # multi_modal_data = self.image_processor(images=image, return_tensors="pt")["pixel_values"].to(dtype=torch.float16)
-            # image input is expected to be (image_num, channel, height, width)
-            # logger.info(f"Image features shape: {multi_modal_data.shape}")
-            # multi_modal_data = multi_modal_data.unsqueeze(0)
+            multi_modal_data = {
+                "image": await self.image_loader.load_image(request.image_url)
+            }
 
         # Remove the image features from the request as they are not required
         request.image_url = None
@@ -331,7 +319,7 @@ async def generate(self, request: vLLMMultimodalRequest):
         gen = self.engine_client.generate(
             prompt=TokensPrompt(
                 prompt_token_ids=pd_request.engine_prompt["prompt_token_ids"],
-                multi_modal_data={"image": multi_modal_data},
+                multi_modal_data=multi_modal_data,
             ),
            sampling_params=pd_request.sampling_params,
             request_id=pd_request.request_id,
diff --git a/examples/multimodal/utils/encode_utils.py b/examples/multimodal/utils/encode_utils.py
new file mode 100644
index 0000000000..0b0f97efaf
--- /dev/null
+++ 
b/examples/multimodal/utils/encode_utils.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Dict, Optional + +import torch + +from .model import SupportedModels + +logger = logging.getLogger(__name__) + + +def get_qwen_image_features( + vision_encoder: torch.nn.Module, image_embeds: Dict[str, Any] +) -> torch.Tensor: + """ + Extract image features using Qwen-style vision encoder. + + Args: + vision_encoder: The vision encoder model + image_embeds: Dictionary containing pixel values and grid information + + Returns: + Processed image features tensor + + Raises: + ValueError: If grid_thw is not provided for Qwen model + """ + pixel_values = image_embeds["pixel_values"].to(vision_encoder.device) + + grid_thw = image_embeds.get("image_grid_thw", None) + if grid_thw is not None: + grid_thw = grid_thw.to(vision_encoder.device) + logger.debug(f"Qwen grid_thw shape: {grid_thw.shape}") + else: + raise ValueError("grid_thw is not provided") + + return ( + vision_encoder.get_image_features(pixel_values, grid_thw) # type: ignore + if grid_thw is not None + else vision_encoder.get_image_features(pixel_values) # type: ignore + ) + + +def encode_image_embeddings( + model_name: str, + image_embeds: Dict[str, Any], + vision_encoder: torch.nn.Module, + projector: Optional[torch.nn.Module] = None, +) -> torch.Tensor: + """ + Encode image embeddings using the appropriate model-specific encoder. + + Args: + model_name: The model identifier + image_embeds: Dictionary containing processed image data + vision_encoder: The vision encoder module + projector: The multimodal projector (required for LLaVA-style models) + + Returns: + Encoded embeddings tensor with normalized shape + + Raises: + ValueError: If projector is missing for LLaVA models + NotImplementedError: If model is not supported + """ + with torch.no_grad(): + # Route through the correct encoder based on model + if model_name == SupportedModels.LLAVA_1_5_7B: + pixel_values = image_embeds["pixel_values"].to(vision_encoder.device) + vision_outputs = vision_encoder(pixel_values) + + if projector is None: + raise ValueError(f"Projector not found for LLaVA model: {model_name}") + + embeddings = projector(vision_outputs.last_hidden_state) + + elif model_name == SupportedModels.QWEN_2_5_VL_7B: + embeddings = get_qwen_image_features(vision_encoder, image_embeds) + + else: + raise NotImplementedError(f"Model not supported: {model_name}") + + # Normalize output shape + if isinstance(embeddings, (tuple, list)): + embeddings = embeddings[0] + embeddings = embeddings.unsqueeze(0) if embeddings.ndim == 2 else embeddings + + return embeddings + + +def get_encoder_components( + model_name: str, vision_model: torch.nn.Module +) -> tuple[Any, Optional[Any]]: + """ + Get the appropriate vision encoder and projector components for a given model. 
+ + Args: + model_name: The model identifier + vision_model: The loaded vision model + + Returns: + Tuple of (vision_encoder, projector) where types depend on the model + + Raises: + NotImplementedError: If model is not supported + """ + if model_name == SupportedModels.LLAVA_1_5_7B: + vision_encoder = vision_model.vision_tower + projector = getattr(vision_model, "multi_modal_projector", None) + return vision_encoder, projector + + elif model_name == SupportedModels.QWEN_2_5_VL_7B: + vision_encoder = vision_model + projector = None + return vision_encoder, projector + + else: + raise NotImplementedError(f"Model not supported: {model_name}") diff --git a/examples/multimodal/utils/image_loader.py b/examples/multimodal/utils/image_loader.py index 403d002151..fa313a65df 100644 --- a/examples/multimodal/utils/image_loader.py +++ b/examples/multimodal/utils/image_loader.py @@ -31,7 +31,9 @@ class ImageLoader: def __init__(self, cache_size: int = CACHE_SIZE_MAXIMUM): self._http_timeout = 30.0 - self._http_client = httpx.AsyncClient(timeout=self._http_timeout) + self._http_client = httpx.AsyncClient( + timeout=self._http_timeout, follow_redirects=True + ) self._image_cache: dict[str, Image.Image] = {} self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size) diff --git a/examples/multimodal/utils/model.py b/examples/multimodal/utils/model.py index 3f7338f1b6..370b8039cd 100644 --- a/examples/multimodal/utils/model.py +++ b/examples/multimodal/utils/model.py @@ -14,61 +14,47 @@ # limitations under the License. import logging -from typing import Any, Dict, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch -from transformers import AutoConfig -from utils.protocol import EncodeResponse -from vllm import AsyncEngineArgs -from vllm.utils import get_distributed_init_method, get_ip, get_open_port -from vllm.worker.worker import Worker +from transformers import AutoConfig, AutoModel -# from transformers import AutoImageProcessor, LlavaForConditionalGeneration -# from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor +logger = logging.getLogger(__name__) -logger = logging.getLogger(__name__) +class SupportedModels: + """Supported multimodal model identifiers""" + + LLAVA_1_5_7B = "llava-hf/llava-1.5-7b-hf" + QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct" + LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf" -# [gluo NOTE] in vLLM v1, Worker() usage below will results in NotImplementedError, -# must find another way to properly load the vision model given the model name (model_id). def load_vision_model(model_id: str) -> torch.nn.Module: """ Load a vision model from a HuggingFace model ID. """ - engine_args = AsyncEngineArgs(model=model_id, trust_remote_code=True) - - engine_config = engine_args.create_engine_config() - distributed_init_method = get_distributed_init_method(get_ip(), get_open_port()) - worker = Worker( - vllm_config=engine_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, - is_driver_worker=True, + model = AutoModel.from_pretrained( + model_id, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True ) - # Initialize the worker. 
- worker.init_device() - worker.load_model() - return worker.model_runner.model - # model = LlavaForConditionalGeneration.from_pretrained( - # model_id, device_map="auto", torch_dtype=torch.float16 - # ).eval() - - # model = Qwen2_5_VLForConditionalGeneration.from_pretrained( - # model_id, torch_dtype="auto", device_map="auto" - # ).eval() - # return model + return model def get_vision_embeddings_info( - model_id: str, num_patches: int + model_id: str, ) -> Tuple[Tuple[int, int, int], torch.dtype]: """Calculate vision embeddings size and dtype using model config - Returns a tuple of (batch_size, num_patches, hidden_dim), dtype. + Returns a tuple of (batch_size, seq_len, hidden_dim), dtype. """ config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) - assert num_patches > 0, "Number of patches must be positive" + + if model_id == SupportedModels.LLAVA_1_5_7B: + seq_len = 577 + elif model_id == SupportedModels.QWEN_2_5_VL_7B: + seq_len = 345 + else: + seq_len = 0 + if not hasattr(config, "torch_dtype"): raise ValueError("Model config missing required 'torch_dtype' attribute") if not hasattr(config, "hidden_size"): @@ -78,29 +64,27 @@ def get_vision_embeddings_info( hidden_size = 4096 else: hidden_size = config.hidden_size - return (1, num_patches, hidden_size), config.torch_dtype + return (1, seq_len, hidden_size), config.torch_dtype def construct_mm_data( model: str, - encode_output: EncodeResponse, image_embeds: torch.Tensor, embeddings_dtype: torch.dtype, + image_grid_thw: Optional[List[Any]], ) -> Dict[str, torch.Tensor | Dict[str, Any]]: """Construct multimodal data for a vLLM request for models that require additional parameters alongside the embeddings""" image_embeds = image_embeds.to(embeddings_dtype) - if "Qwen2" in model: + if model == SupportedModels.QWEN_2_5_VL_7B: + if image_grid_thw is not None and len(image_grid_thw) > 0: + grid_thw_tensor = torch.tensor(image_grid_thw) + else: + raise ValueError("No image grid provided.") + return { "image": { "image_embeds": image_embeds.squeeze(0), - "image_grid_thw": torch.tensor(encode_output.image_grid_thw).squeeze(0), - } - } - elif "MiniCPM-V" in model: - return { - "image": { - "image_embeds": image_embeds, - "image_sizes": encode_output.image_sizes, + "image_grid_thw": grid_thw_tensor, } } else: diff --git a/examples/multimodal/utils/protocol.py b/examples/multimodal/utils/protocol.py index f5083cc441..15e66f09c0 100644 --- a/examples/multimodal/utils/protocol.py +++ b/examples/multimodal/utils/protocol.py @@ -15,7 +15,7 @@ import json -from typing import Any, List, Literal, Optional, Union +from typing import Any, List, Literal, Optional, Tuple, Union import msgspec from pydantic import BaseModel, ConfigDict, field_validator @@ -127,7 +127,8 @@ class MultiModalRequest(BaseModel): class vLLMMultimodalRequest(vLLMGenerateRequest): model_config = ConfigDict(arbitrary_types_allowed=True) image_url: Optional[str] = None - # image_features: Optional[List[List[List[float]]]] = None # Remove once have NIXL support + image_grid_thw: Optional[List[Any]] = None + embeddings_shape: Optional[Tuple[int, int, int]] = None serialized_request: Optional[connect.RdmaMetadata] = None @@ -142,15 +143,6 @@ class EncodeRequest(BaseModel): serialized_request: Optional[connect.RdmaMetadata] = None -class EncodeResponse(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - request_id: str - image_grid_thw: Optional[List[Any]] = None - image_sizes: Optional[List[Any]] = None - serialized_request: Optional[connect.RdmaMetadata] 
= None - image_features: List[List[List[float]]] # Remove once have NIXL support - - class MyRequestOutput(BaseModel): """ RequestOutput from vLLM is not serializable by default diff --git a/tests/serve/test_vllm.py b/tests/serve/test_vllm.py index 31ec74ab4e..0d3363c420 100644 --- a/tests/serve/test_vllm.py +++ b/tests/serve/test_vllm.py @@ -166,8 +166,8 @@ def __init__(self, config: VLLMConfig, request): ], timeout=560, ), - "multimodal_agg": VLLMConfig( - name="multimodal_agg", + "multimodal_agg_llava": VLLMConfig( + name="multimodal_agg_llava", directory="/workspace/examples/multimodal", script_name="agg.sh", marks=[pytest.mark.gpu_2, pytest.mark.vllm], @@ -180,6 +180,20 @@ def __init__(self, config: VLLMConfig, request): args=["--model", "llava-hf/llava-1.5-7b-hf"], timeout=360, ), + "multimodal_agg_qwen": VLLMConfig( + name="multimodal_agg_qwen", + directory="/workspace/examples/multimodal", + script_name="agg.sh", + marks=[pytest.mark.gpu_2, pytest.mark.vllm], + endpoints=["v1/chat/completions"], + response_handlers=[ + chat_completions_response_handler, + ], + model="Qwen/Qwen2.5-VL-7B-Instruct", + delayed_start=0, + args=["--model", "Qwen/Qwen2.5-VL-7B-Instruct"], + timeout=360, + ), # TODO: Enable this test case when we have 4 GPUs runners. # "multimodal_disagg": VLLMConfig( # name="multimodal_disagg",
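
To exercise the new `multimodal_agg_qwen` path manually, the README's client example can be reused with the model name swapped. A minimal sketch, assuming the same container layout used above; the image URL and sampling fields below are illustrative placeholders, not part of this change:

```bash
# Launch the aggregated Qwen example (same script as the LLaVA case).
cd $DYNAMO_HOME/examples/multimodal
bash launch/agg.sh --model Qwen/Qwen2.5-VL-7B-Instruct

# Send an OpenAI-compatible chat completions request, mirroring the README's
# existing example with the model field swapped. Substitute a reachable image URL.
curl http://localhost:8080/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d '{
    "model": "Qwen/Qwen2.5-VL-7B-Instruct",
    "messages": [
      {
        "role": "user",
        "content": [
          {"type": "text", "text": "What is in this image?"},
          {"type": "image_url", "image_url": {"url": "http://images.example.com/bus.jpg"}}
        ]
      }
    ],
    "max_tokens": 300,
    "stream": false
  }'
```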