Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

85 changes: 70 additions & 15 deletions components/src/dynamo/vllm/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Final

import PIL
import torch
from vllm.inputs import TokensPrompt
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.exceptions import EngineDeadError

import dynamo.nixl_connect as connect
from dynamo.llm import ZmqKvEventPublisher
from dynamo.nixl_connect import OperationKind, RdmaMetadata, SerializedDescriptor
from dynamo.runtime.logging import configure_dynamo_logging

from .engine_monitor import VllmEngineMonitor
Expand Down Expand Up @@ -73,6 +77,7 @@ def __init__(self, runtime, component, engine, default_sampling_params):
self.kv_publishers: list[ZmqKvEventPublisher] | None = None
self.engine_monitor = VllmEngineMonitor(runtime, engine)
self.image_loader = ImageLoader()
self._connector = None # Lazy-initialized on first Decoded variant

@abstractmethod
async def generate(self, request, context) -> AsyncGenerator[dict, None]:
Expand Down Expand Up @@ -119,11 +124,63 @@ def cleanup(self):
"""Override in subclasses if cleanup is needed."""
pass

async def _read_decoded_image_via_nixl(
self, decoded_meta: Dict[str, Any]
) -> PIL.Image.Image:
"""Read decoded image via NIXL RDMA and convert to PIL.Image."""
# Lazy-init connector
if self._connector is None:
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("NIXL connector initialized for decoded media")

# Extract fields
meta_str = decoded_meta["nixl_metadata"]
desc = decoded_meta["nixl_descriptor"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what type is desc?

shape = decoded_meta["shape"]

# Create tensor to receive RDMA data
tensor = torch.empty(shape, dtype=torch.uint8)

# Build RdmaMetadata from frontend-provided descriptor
# Frontend sends compressed metadata (matches Python nixl_connect)
rdma_meta = RdmaMetadata(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this work, have you tested it?

the "normal flow" is to create a passive operation (ReadableOp or WritableOp) and use their .metadata property to get the set of SerializedDescriptors and not manually compose this.

given that this is using an active operation (ReadOp) it should be taking in the metadata to perform the read, not sending the metadata.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

usually metadata comes from the secondary connection, which in turn got it from its ReadableOperation.

descriptors=[
SerializedDescriptor(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not pass desc to a Descriptor and then serialize the descriptor to get the metadata?

device="cpu"
if desc.get("mem_type") == "Dram"
else f"cuda:{desc.get('device_id', 0)}",
ptr=desc["addr"],
size=desc["size"],
)
],
nixl_metadata=meta_str,
notification_key=f"img-{shape}",
operation_kind=int(OperationKind.READ),
)

# RDMA read
read_op = await self._connector.begin_read(
rdma_meta, connect.Descriptor(tensor)
)
await read_op.wait_for_completion()
Comment on lines +158 to +197
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a NIXL expert, so please let me know if I can be doing anything here better.


# Convert to PIL.Image (assume RGB, handle RGBA/grayscale)
arr = tensor.numpy()
modes = {1: "L", 3: "RGB", 4: "RGBA"}
if modes[shape[2]] == "L":
arr = arr.squeeze(-1)
return PIL.Image.fromarray(arr, modes[shape[2]])

async def _extract_multimodal_data(
self, request: Dict[str, Any]
) -> Dict[str, Any] | None:
"""
Extract and decode multimodal data from PreprocessedRequest.

Supports two variants:
1. Url: Frontend passes URL, backend decodes
2. Decoded: Frontend decoded, NIXL RDMA transfer
"""
if "multi_modal_data" not in request or request["multi_modal_data"] is None:
return None
Expand All @@ -134,22 +191,20 @@ async def _extract_multimodal_data(
# Process image_url entries
images = []
for item in mm_map.get(IMAGE_URL_KEY, []):
if isinstance(item, dict) and URL_VARIANT_KEY in item:
if isinstance(item, dict) and DECODED_VARIANT_KEY in item:
decoded_meta = item[DECODED_VARIANT_KEY]
image = await self._read_decoded_image_via_nixl(decoded_meta)
images.append(image)
logger.info(
f"Using DECODED path: Loaded image via NIXL RDMA "
f"(shape={decoded_meta.get('shape')}, dtype={decoded_meta.get('dtype')})"
)
elif isinstance(item, dict) and URL_VARIANT_KEY in item:
url = item[URL_VARIANT_KEY]
try:
# ImageLoader supports both data: and http(s): URLs with caching
image = await self.image_loader.load_image(url)
images.append(image)
logger.debug(f"Loaded image from URL: {url[:80]}...")
except Exception:
logger.exception(f"Failed to load image from {url[:80]}...")
raise
elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
# Decoded support from PRs #3971/#3988 (frontend decoding + NIXL transfer)
# Will contain NIXL metadata for direct memory access
# TODO: Implement NIXL read when PRs merge
logger.warning(
"Decoded multimodal data not yet supported in standard worker"
image = await self.image_loader.load_image(url)
images.append(image)
logger.info(
f"Using URL path: Loaded image from URL (type={url.split(':')[0]})"
)

if images:
Expand Down
4 changes: 4 additions & 0 deletions components/src/dynamo/vllm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from dynamo.common.config_dump import dump_config
from dynamo.common.utils.prometheus import register_engine_metrics_callback
from dynamo.llm import (
MediaDecoder,
MediaFetcher,
ModelInput,
ModelRuntimeConfig,
ModelType,
Expand Down Expand Up @@ -308,6 +310,8 @@ async def register_vllm_model(
migration_limit=migration_limit,
runtime_config=runtime_config,
custom_template_path=config.custom_jinja_template,
media_decoder=MediaDecoder(),
media_fetcher=MediaFetcher(),
)


Expand Down
Loading
Loading