examples/multimodal/README.md (6 changes: 5 additions & 1 deletion)
@@ -59,12 +59,14 @@ flowchart LR
pd_worker --> encode_worker
```

***Note*** Only the LLaVA 1.5 7B model is supported. Qwen2.5-VL and Phi3V support will be added in the future.
***Note*** Only the LLaVA 1.5 7B and Qwen2.5-VL models are supported. Phi3V support will be added in the future.

```bash
cd $DYNAMO_HOME/examples/multimodal
# Serve a LLaVA 1.5 7B model:
bash launch/agg.sh --model llava-hf/llava-1.5-7b-hf
# Serve a Qwen2.5-VL model:
bash launch/agg.sh --model Qwen/Qwen2.5-VL-7B-Instruct
```

### Client
@@ -98,6 +100,8 @@ curl http://localhost:8080/v1/chat/completions \
}'
```

If serving the example Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`.
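
The full Qwen request is otherwise identical; a minimal hypothetical sketch is shown below, assuming the same OpenAI-style multimodal chat payload as the LLaVA example above (the image URL is only a placeholder):

```bash
# Sketch only: the request matches the LLaVA example except for the "model" field.
curl http://localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Qwen/Qwen2.5-VL-7B-Instruct",
    "messages": [
      {
        "role": "user",
        "content": [
          {"type": "text", "text": "What is in this image?"},
          {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
        ]
      }
    ],
    "max_tokens": 300
  }'
```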

You should see a response similar to this:
```json
{"id": "c37b946e-9e58-4d54-88c8-2dbd92c47b0c", "object": "chat.completion", "created": 1747725277, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " In the image, there is a city bus parked on a street, with a street sign nearby on the right side. The bus appears to be stopped out of service. The setting is in a foggy city, giving it a slightly moody atmosphere."}, "finish_reason": "stop"}]}
examples/multimodal/components/encode_worker.py (105 changes: 60 additions & 45 deletions)
@@ -23,7 +23,7 @@

import torch
import uvloop
from transformers import AutoImageProcessor, LlavaForConditionalGeneration
from transformers import AutoImageProcessor
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser

@@ -34,6 +34,7 @@
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from utils.args import Config, base_parse_args, parse_endpoint
from utils.image_loader import ImageLoader
from utils.model import load_vision_model
from utils.protocol import MyRequestOutput, vLLMMultimodalRequest

configure_dynamo_logging()
@@ -70,16 +71,39 @@ def __init__(
self.image_processor = AutoImageProcessor.from_pretrained(
self.model, trust_remote_code=True
)
# self.vision_model = load_vision_model(self.model)
self.vision_model = LlavaForConditionalGeneration.from_pretrained(
self.model, device_map="auto", torch_dtype=torch.float16
).eval()

self.vision_model = load_vision_model(self.model)
self.min_workers = 1

        # Detect the vision encoder and projector for the supported model families.
if "llava" in self.model.lower():
self.vision_encoder = self.vision_model.vision_tower
self.projector = getattr(self.vision_model, "multi_modal_projector", None)

elif "qwen" in self.model.lower():
self.vision_encoder = self.vision_model
self.projector = None
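            # The Qwen path computes image features in a single call (see
            # get_qwen_image_features below), so no separate projector is tracked here.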
else:
raise NotImplementedError(f"Model not supported: {self.model}")

def cleanup(self):
pass

def get_qwen_image_features(self, vision_encoder, image_embeds):
pixel_values = image_embeds["pixel_values"].to(vision_encoder.device)

grid_thw = image_embeds.get("image_grid_thw", None)
if grid_thw is not None:
grid_thw = grid_thw.to(vision_encoder.device)
logger.debug(f"Qwen grid_thw shape: {grid_thw.shape}")
else:
raise ValueError("grid_thw is not provided")

        # grid_thw is guaranteed to be non-None here (see the check above), so pass it directly.
        return vision_encoder.get_image_features(pixel_values, grid_thw)

async def generate(
self, request: vLLMMultimodalRequest
) -> AsyncIterator[MyRequestOutput]:
@@ -108,49 +132,40 @@ async def generate(

logger.debug(f"Processing image for request: {{ id: {request_id} }}")
image_embeds = self.image_processor(images=image, return_tensors="pt")
# [gluo NOTE] The commented section is for VLM generalization support,
# will use more generic approach once utils/model.py is fixed,
# see utils/models.py for details.
# # Add a batch dimension to everything
# for item in image_embeds:
# image_embeds[item] = image_embeds[item].unsqueeze(0).to(DEVICE)
# logger.debug(f"Image embeds: {image_embeds}")

# image_grid_thw = (
# image_embeds["image_grid_thw"].tolist()
# if "image_grid_thw" in image_embeds
# else None
# )
# image_sizes = (
# image_embeds["image_sizes"].tolist()
# if "image_sizes" in image_embeds
# else [image.size]
# )
# logger.debug(
# f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
# )

# with torch.no_grad():
# embeddings = self.vision_model.get_multimodal_embeddings(**image_embeds)
# if isinstance(embeddings, tuple) or isinstance(embeddings, list):
# # The result multimodal_embeddings may be a list or tuple of tensors, with each
# # tensor corresponding to a multimodal data item (image or video).
# # TODO: for multi-image support, this result will contain multiple tensors.
# embeddings = embeddings[0].unsqueeze(0)
# logger.debug(
# f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}."
# )

with torch.no_grad():
logger.debug(f"Vision model device: {self.vision_model.device}")
vision_outputs = self.vision_model.vision_tower(
image_embeds["pixel_values"].to(self.vision_model.device)
# Route through the correct encoder
if "llava" in self.model.lower():
pixel_values = image_embeds["pixel_values"].to(
self.vision_encoder.device
)
vision_outputs = self.vision_encoder(pixel_values)
embeddings = self.projector(vision_outputs.last_hidden_state)
elif "qwen" in self.model.lower():
embeddings = self.get_qwen_image_features(
self.vision_encoder, image_embeds
)
else:
raise NotImplementedError(f"Model not supported: {self.model}")

# Normalize output shape
if isinstance(embeddings, (tuple, list)):
embeddings = embeddings[0]
embeddings = (
embeddings.unsqueeze(0) if embeddings.ndim == 2 else embeddings
)
logger.debug("Vision model completed.")

embeddings = vision_outputs.last_hidden_state
embeddings = self.vision_model.multi_modal_projector(embeddings)

image_grid_thw = (
image_embeds["image_grid_thw"].tolist()
if "image_grid_thw" in image_embeds
else None
)
logger.debug(
f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
)

request.image_grid_thw = image_grid_thw
request.embeddings_shape = tuple(embeddings.shape)
descriptor = connect.Descriptor(embeddings)

with self._connector.create_readable(descriptor) as readable:
examples/multimodal/components/worker.py (70 changes: 39 additions & 31 deletions)
@@ -24,7 +24,6 @@

import torch
import uvloop
from transformers import AutoImageProcessor
from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt
@@ -47,6 +46,7 @@
parse_endpoint,
)
from utils.image_loader import ImageLoader
from utils.model import construct_mm_data, get_vision_embeddings_info
from utils.protocol import MyRequestOutput, vLLMMultimodalRequest

configure_dynamo_logging()
@@ -245,37 +245,32 @@ async def async_init(self, runtime: DistributedRuntime):
.client()
)

EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cpu"
self.EMBEDDINGS_DTYPE = torch.float16
self.EMBEDDINGS_DEVICE = "cpu"
# Create and initialize a dynamo connector for this worker.
        # We'll need this to move data between this worker and remote workers efficiently.
parsed_namespace, _, _ = parse_endpoint(self.endpoint)
self._connector = connect.Connector()
await self._connector.initialize()

# embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
# self.engine_args.model, self.engine_args.num_patches
# )
# [gluo NOTE] Hardcoded for now, will use more generic approach once utils/model.py
# is fixed, see utils/models.py for details.
embeddings_shape = (1, 577, 4096)
logger.debug(f"Embeddings shape: {embeddings_shape}")
self.embedding_size = embeddings_shape[1]

embeddings = torch.empty(
embeddings_shape, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
self.embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
self.engine_args.model
)

descriptor = connect.Descriptor(embeddings)
logger.debug(f"Embeddings shape: {self.embeddings_shape}")
self._embeddings_descriptor = None

        # Register the descriptor w/ NIXL (this is optional; if not done here, the connect subsystem will take care of it automatically).
# descriptor.register_memory(self._connector)
self._embeddings_descriptor = (embeddings, descriptor)
if self.embeddings_shape[1] != 0:
embeddings = torch.empty(
self.embeddings_shape,
dtype=self.EMBEDDINGS_DTYPE,
device=self.EMBEDDINGS_DEVICE,
)
descriptor = connect.Descriptor(embeddings)
            # Register the descriptor w/ NIXL (this is optional; if not done here, the connect subsystem will take care of it automatically).
self._embeddings_descriptor = (embeddings, descriptor)

self.image_loader = ImageLoader()
self.image_processor = AutoImageProcessor.from_pretrained(
self.engine_args.model, trust_remote_code=True
)

logger.info("VllmPDWorker has been initialized")

@@ -288,10 +283,21 @@ async def generate(self, request: vLLMMultimodalRequest):
request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")

if request.image_url is None:
# Process embeddings using the connector
embeddings, descriptor = None, None

# Process embeddings using the connector
if self._embeddings_descriptor:
embeddings, descriptor = self._embeddings_descriptor
else:
            # No pre-allocated descriptor exists, so create one sized to the request's embeddings shape.
embeddings = torch.empty(
request.embeddings_shape,
dtype=self.EMBEDDINGS_DTYPE,
device=self.EMBEDDINGS_DEVICE,
)
descriptor = connect.Descriptor(embeddings)

if request.image_url is None:
if descriptor is None:
raise RuntimeError(
"Descriptor is None in PD worker - cannot process embeddings"
@@ -301,15 +307,17 @@
request.serialized_request, descriptor
)
await read_op.wait_for_completion()
logger.debug(f"in PD worker, image features: {embeddings}")
multi_modal_data = embeddings
multi_modal_data = construct_mm_data(
self.engine_args.model,
embeddings,
self.EMBEDDINGS_DTYPE,
request.image_grid_thw,
)
else:
# Use PIL image instead of image embeddings
multi_modal_data = await self.image_loader.load_image(request.image_url)
# multi_modal_data = self.image_processor(images=image, return_tensors="pt")["pixel_values"].to(dtype=torch.float16)
# image input is expected to be (image_num, channel, height, width)
# logger.info(f"Image features shape: {multi_modal_data.shape}")
# multi_modal_data = multi_modal_data.unsqueeze(0)
multi_modal_data = {
"image": await self.image_loader.load_image(request.image_url)
}

# Remove the image features from the request as they are not required
request.image_url = None
@@ -331,7 +339,7 @@ async def generate(self, request: vLLMMultimodalRequest):
gen = self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=pd_request.engine_prompt["prompt_token_ids"],
multi_modal_data={"image": multi_modal_data},
multi_modal_data=multi_modal_data,
),
sampling_params=pd_request.sampling_params,
request_id=pd_request.request_id,