6 changes: 5 additions & 1 deletion examples/multimodal/README.md
@@ -59,12 +59,14 @@ flowchart LR
pd_worker --> encode_worker
```

***Note*** Only the LLaVA 1.5 7B model is supported. Qwen2.5-VL and Phi3V support will be added in the future.
***Note*** Aggregated serving currently supports LLaVA 1.5 7B and Qwen2.5-VL-7B-Instruct; Phi3V support will be added in the future. Disaggregated serving is currently confirmed only for LLaVA (see the note below).

```bash
cd $DYNAMO_HOME/examples/multimodal
# Serve a LLaVA 1.5 7B model:
bash launch/agg.sh --model llava-hf/llava-1.5-7b-hf
# Serve a Qwen2.5-VL model:
bash launch/agg.sh --model Qwen/Qwen2.5-VL-7B-Instruct
```

### Client
@@ -98,6 +100,8 @@ curl http://localhost:8080/v1/chat/completions \
}'
```

If serving the Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field of the request with `"Qwen/Qwen2.5-VL-7B-Instruct"`.
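
For example, the equivalent Qwen request might look like the sketch below. It assumes the same OpenAI-style chat payload as the LLaVA example above; the prompt, image URL, and token limit are placeholders:

```bash
curl http://localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Qwen/Qwen2.5-VL-7B-Instruct",
    "messages": [
      {
        "role": "user",
        "content": [
          { "type": "text", "text": "What is in this image?" },
          { "type": "image_url", "image_url": { "url": "https://example.com/image.jpg" } }
        ]
      }
    ],
    "max_tokens": 300
  }'
```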

For the LLaVA request, you should see a response similar to this:
```json
{"id": "c37b946e-9e58-4d54-88c8-2dbd92c47b0c", "object": "chat.completion", "created": 1747725277, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " In the image, there is a city bus parked on a street, with a street sign nearby on the right side. The bus appears to be stopped out of service. The setting is in a foggy city, giving it a slightly moody atmosphere."}, "finish_reason": "stop"}]}
77 changes: 28 additions & 49 deletions examples/multimodal/components/encode_worker.py
@@ -21,9 +21,8 @@
import sys
from typing import AsyncIterator, Tuple

import torch
import uvloop
from transformers import AutoImageProcessor, LlavaForConditionalGeneration
from transformers import AutoImageProcessor
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser

@@ -33,7 +32,9 @@

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from utils.args import Config, base_parse_args, parse_endpoint
from utils.encode_utils import encode_image_embeddings, get_encoder_components
from utils.image_loader import ImageLoader
from utils.model import load_vision_model
from utils.protocol import MyRequestOutput, vLLMMultimodalRequest

configure_dynamo_logging()
@@ -70,13 +71,14 @@ def __init__(
self.image_processor = AutoImageProcessor.from_pretrained(
self.model, trust_remote_code=True
)
# self.vision_model = load_vision_model(self.model)
self.vision_model = LlavaForConditionalGeneration.from_pretrained(
self.model, device_map="auto", torch_dtype=torch.float16
).eval()

self.vision_model = load_vision_model(self.model)
self.min_workers = 1

# Get encoder components for the model
self.vision_encoder, self.projector = get_encoder_components(
self.model, self.vision_model
)

def cleanup(self):
pass

@@ -108,49 +110,26 @@ async def generate(

logger.debug(f"Processing image for request: {{ id: {request_id} }}")
image_embeds = self.image_processor(images=image, return_tensors="pt")
# [gluo NOTE] The commented section is for VLM generalization support,
# will use more generic approach once utils/model.py is fixed,
# see utils/models.py for details.
# # Add a batch dimension to everything
# for item in image_embeds:
# image_embeds[item] = image_embeds[item].unsqueeze(0).to(DEVICE)
# logger.debug(f"Image embeds: {image_embeds}")

# image_grid_thw = (
# image_embeds["image_grid_thw"].tolist()
# if "image_grid_thw" in image_embeds
# else None
# )
# image_sizes = (
# image_embeds["image_sizes"].tolist()
# if "image_sizes" in image_embeds
# else [image.size]
# )
# logger.debug(
# f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
# )

# with torch.no_grad():
# embeddings = self.vision_model.get_multimodal_embeddings(**image_embeds)
# if isinstance(embeddings, tuple) or isinstance(embeddings, list):
# # The result multimodal_embeddings may be a list or tuple of tensors, with each
# # tensor corresponding to a multimodal data item (image or video).
# # TODO: for multi-image support, this result will contain multiple tensors.
# embeddings = embeddings[0].unsqueeze(0)
# logger.debug(
# f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}."
# )

with torch.no_grad():
logger.debug(f"Vision model device: {self.vision_model.device}")
vision_outputs = self.vision_model.vision_tower(
image_embeds["pixel_values"].to(self.vision_model.device)
)
logger.debug("Vision model completed.")

embeddings = vision_outputs.last_hidden_state
embeddings = self.vision_model.multi_modal_projector(embeddings)

# Encode the image embeddings using model-specific encoder
embeddings = encode_image_embeddings(
model_name=self.model,
image_embeds=image_embeds,
vision_encoder=self.vision_encoder,
projector=self.projector,
)

image_grid_thw = (
image_embeds["image_grid_thw"].tolist()
if "image_grid_thw" in image_embeds
else None
)
logger.debug(
f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
)

request.image_grid_thw = image_grid_thw
request.embeddings_shape = tuple(embeddings.shape)
descriptor = connect.Descriptor(embeddings)

with self._connector.create_readable(descriptor) as readable:
60 changes: 24 additions & 36 deletions examples/multimodal/components/worker.py
@@ -24,7 +24,6 @@

import torch
import uvloop
from transformers import AutoImageProcessor
from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt
@@ -47,6 +46,7 @@
parse_endpoint,
)
from utils.image_loader import ImageLoader
from utils.model import construct_mm_data
from utils.protocol import MyRequestOutput, vLLMMultimodalRequest

configure_dynamo_logging()
@@ -245,37 +245,15 @@ async def async_init(self, runtime: DistributedRuntime):
.client()
)

EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cpu"
self.EMBEDDINGS_DTYPE = torch.float16
self.EMBEDDINGS_DEVICE = "cpu"
# Create and initialize a dynamo connector for this worker.
        # We'll need this to move data between this worker and remote workers efficiently.
parsed_namespace, _, _ = parse_endpoint(self.endpoint)
self._connector = connect.Connector()
await self._connector.initialize()

# embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
# self.engine_args.model, self.engine_args.num_patches
# )
# [gluo NOTE] Hardcoded for now, will use more generic approach once utils/model.py
# is fixed, see utils/models.py for details.
embeddings_shape = (1, 577, 4096)
logger.debug(f"Embeddings shape: {embeddings_shape}")
self.embedding_size = embeddings_shape[1]

embeddings = torch.empty(
embeddings_shape, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
)

descriptor = connect.Descriptor(embeddings)

        # Register the descriptor w/ NIXL (this is optional; if not done here, the connect subsystem will take care of it automatically).
# descriptor.register_memory(self._connector)
self._embeddings_descriptor = (embeddings, descriptor)

self.image_loader = ImageLoader()
self.image_processor = AutoImageProcessor.from_pretrained(
self.engine_args.model, trust_remote_code=True
)

logger.info("VllmPDWorker has been initialized")

@@ -288,10 +266,18 @@ async def generate(self, request: vLLMMultimodalRequest):
request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")

if request.image_url is None:
# Process embeddings using the connector
embeddings, descriptor = self._embeddings_descriptor
embeddings, descriptor = None, None

# Process embeddings using the connector
# Create a descriptor based on the embedding shape.
embeddings = torch.empty(
request.embeddings_shape,
dtype=self.EMBEDDINGS_DTYPE,
device=self.EMBEDDINGS_DEVICE,
)
descriptor = connect.Descriptor(embeddings)

if request.image_url is None:
if descriptor is None:
raise RuntimeError(
"Descriptor is None in PD worker - cannot process embeddings"
@@ -301,15 +287,17 @@
request.serialized_request, descriptor
)
await read_op.wait_for_completion()
logger.debug(f"in PD worker, image features: {embeddings}")
multi_modal_data = embeddings
multi_modal_data = construct_mm_data(
self.engine_args.model,
embeddings,
self.EMBEDDINGS_DTYPE,
request.image_grid_thw,
)
else:
# Use PIL image instead of image embeddings
multi_modal_data = await self.image_loader.load_image(request.image_url)
# multi_modal_data = self.image_processor(images=image, return_tensors="pt")["pixel_values"].to(dtype=torch.float16)
# image input is expected to be (image_num, channel, height, width)
# logger.info(f"Image features shape: {multi_modal_data.shape}")
# multi_modal_data = multi_modal_data.unsqueeze(0)
multi_modal_data = {
"image": await self.image_loader.load_image(request.image_url)
}

# Remove the image features from the request as they are not required
request.image_url = None
@@ -331,7 +319,7 @@ async def generate(self, request: vLLMMultimodalRequest):
gen = self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=pd_request.engine_prompt["prompt_token_ids"],
multi_modal_data={"image": multi_modal_data},
multi_modal_data=multi_modal_data,
),
sampling_params=pd_request.sampling_params,
request_id=pd_request.request_id,
132 changes: 132 additions & 0 deletions examples/multimodal/utils/encode_utils.py
@@ -0,0 +1,132 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Any, Dict, Optional

import torch

from .model import SupportedModels

logger = logging.getLogger(__name__)


def get_qwen_image_features(
vision_encoder: torch.nn.Module, image_embeds: Dict[str, Any]
) -> torch.Tensor:
"""
Extract image features using Qwen-style vision encoder.

Args:
vision_encoder: The vision encoder model
image_embeds: Dictionary containing pixel values and grid information

Returns:
Processed image features tensor

Raises:
ValueError: If grid_thw is not provided for Qwen model
"""
pixel_values = image_embeds["pixel_values"].to(vision_encoder.device)

grid_thw = image_embeds.get("image_grid_thw", None)
if grid_thw is not None:
grid_thw = grid_thw.to(vision_encoder.device)
logger.debug(f"Qwen grid_thw shape: {grid_thw.shape}")
else:
raise ValueError("grid_thw is not provided")

return (
vision_encoder.get_image_features(pixel_values, grid_thw) # type: ignore
if grid_thw is not None
else vision_encoder.get_image_features(pixel_values) # type: ignore
)


def encode_image_embeddings(
model_name: str,
image_embeds: Dict[str, Any],
vision_encoder: torch.nn.Module,
projector: Optional[torch.nn.Module] = None,
) -> torch.Tensor:
"""
Encode image embeddings using the appropriate model-specific encoder.

Args:
model_name: The model identifier
image_embeds: Dictionary containing processed image data
vision_encoder: The vision encoder module
projector: The multimodal projector (required for LLaVA-style models)

Returns:
Encoded embeddings tensor with normalized shape

Raises:
ValueError: If projector is missing for LLaVA models
NotImplementedError: If model is not supported
"""
with torch.no_grad():
# Route through the correct encoder based on model
if model_name == SupportedModels.LLAVA_1_5_7B:
pixel_values = image_embeds["pixel_values"].to(vision_encoder.device)
vision_outputs = vision_encoder(pixel_values)

if projector is None:
raise ValueError(f"Projector not found for LLaVA model: {model_name}")

embeddings = projector(vision_outputs.last_hidden_state)

elif model_name == SupportedModels.QWEN_2_5_VL_7B:
embeddings = get_qwen_image_features(vision_encoder, image_embeds)

else:
raise NotImplementedError(f"Model not supported: {model_name}")

# Normalize output shape
if isinstance(embeddings, (tuple, list)):
embeddings = embeddings[0]
embeddings = embeddings.unsqueeze(0) if embeddings.ndim == 2 else embeddings

return embeddings


def get_encoder_components(
model_name: str, vision_model: torch.nn.Module
) -> tuple[Any, Optional[Any]]:
"""
Get the appropriate vision encoder and projector components for a given model.

Args:
model_name: The model identifier
vision_model: The loaded vision model

Returns:
Tuple of (vision_encoder, projector) where types depend on the model

Raises:
NotImplementedError: If model is not supported
"""
if model_name == SupportedModels.LLAVA_1_5_7B:
vision_encoder = vision_model.vision_tower
projector = getattr(vision_model, "multi_modal_projector", None)
return vision_encoder, projector

elif model_name == SupportedModels.QWEN_2_5_VL_7B:
vision_encoder = vision_model
projector = None
return vision_encoder, projector

else:
raise NotImplementedError(f"Model not supported: {model_name}")
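
For reference, a minimal usage sketch of these helpers, mirroring how `encode_worker.py` wires them together (run from `examples/multimodal` so that `utils` is importable; the image path is a placeholder):

```python
from PIL import Image
from transformers import AutoImageProcessor

from utils.encode_utils import encode_image_embeddings, get_encoder_components
from utils.model import load_vision_model

model_name = "llava-hf/llava-1.5-7b-hf"  # or "Qwen/Qwen2.5-VL-7B-Instruct"

# Load the image processor, the vision model, and its encoder/projector components.
processor = AutoImageProcessor.from_pretrained(model_name, trust_remote_code=True)
vision_model = load_vision_model(model_name)
vision_encoder, projector = get_encoder_components(model_name, vision_model)

# Preprocess a local image (placeholder path) and run the model-specific encoder.
image = Image.open("example.jpg")
image_embeds = processor(images=image, return_tensors="pt")
embeddings = encode_image_embeddings(
    model_name=model_name,
    image_embeds=image_embeds,
    vision_encoder=vision_encoder,
    projector=projector,
)
print(embeddings.shape)
```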
4 changes: 3 additions & 1 deletion examples/multimodal/utils/image_loader.py
@@ -31,7 +31,9 @@ class ImageLoader:

def __init__(self, cache_size: int = CACHE_SIZE_MAXIMUM):
self._http_timeout = 30.0
self._http_client = httpx.AsyncClient(timeout=self._http_timeout)
self._http_client = httpx.AsyncClient(
timeout=self._http_timeout, follow_redirects=True
)
self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size)
