3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

11 changes: 11 additions & 0 deletions components/src/dynamo/vllm/args.py
@@ -65,6 +65,7 @@ class Config:
multimodal_decode_worker: bool = False
multimodal_encode_prefill_worker: bool = False
mm_prompt_template: str = "USER: <image>\n<prompt> ASSISTANT:"
frontend_decoding: bool = False
# dump config to file
dump_config_to: Optional[str] = None

@@ -175,6 +176,16 @@ def parse_args() -> Config:
"'USER: <image> please describe the image ASSISTANT:'."
),
)
parser.add_argument(
"--frontend-decoding",
action="store_true",
help=(
"EXPERIMENTAL: Enable frontend decoding of multimodal images. "
"When enabled, images are decoded in the Rust frontend and transferred to the backend via NIXL RDMA. "
"Requires building Dynamo's Rust components with '--features media-nixl'. "
"Without this flag, images are decoded in the Python backend (default behavior)."
),
)
parser.add_argument(
"--store-kv",
type=str,
85 changes: 70 additions & 15 deletions components/src/dynamo/vllm/handlers.py
@@ -9,12 +9,16 @@
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Final

import PIL.Image
import torch
from vllm.inputs import TokensPrompt
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.exceptions import EngineDeadError

import dynamo.nixl_connect as connect
from dynamo.llm import ZmqKvEventPublisher
from dynamo.nixl_connect import OperationKind, RdmaMetadata, SerializedDescriptor
from dynamo.runtime.logging import configure_dynamo_logging

from .engine_monitor import VllmEngineMonitor
@@ -93,6 +97,7 @@ def __init__(
self.kv_publishers: list[ZmqKvEventPublisher] | None = None
self.engine_monitor = VllmEngineMonitor(runtime, engine)
self.image_loader = ImageLoader()
self._connector = None # Lazy-initialized on first Decoded variant
self.temp_dirs: list[tempfile.TemporaryDirectory] = []
self.model_max_len = model_max_len

@@ -150,11 +155,63 @@ def cleanup(self):
except Exception as e:
logger.warning(f"Failed to clean up temp directory: {e}")

async def _read_decoded_image_via_nixl(
self, decoded_meta: Dict[str, Any]
) -> PIL.Image.Image:
"""Read decoded image via NIXL RDMA and convert to PIL.Image."""
# Lazy-init connector
if self._connector is None:
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("NIXL connector initialized for decoded media")

# Extract fields
meta_str = decoded_meta["nixl_metadata"]
desc = decoded_meta["nixl_descriptor"]
Review comment (Collaborator): what type is desc?

shape = decoded_meta["shape"]

# Create tensor to receive RDMA data
tensor = torch.empty(shape, dtype=torch.uint8)

# Build RdmaMetadata from frontend-provided descriptor
# Frontend sends compressed metadata (matches Python nixl_connect)
rdma_meta = RdmaMetadata(
Review comment (Collaborator): does this work, have you tested it? The "normal flow" is to create a passive operation (ReadableOp or WritableOp) and use their .metadata property to get the set of SerializedDescriptors and not manually compose this. Given that this is using an active operation (ReadOp) it should be taking in the metadata to perform the read, not sending the metadata.

Review comment (Collaborator): usually metadata comes from the secondary connection, which in turn got it from its ReadableOperation.

descriptors=[
SerializedDescriptor(
Review comment (Collaborator): why not pass desc to a Descriptor and then serialize the descriptor to get the metadata?

device="cpu"
if desc.get("mem_type") == "Dram"
else f"cuda:{desc.get('device_id', 0)}",
ptr=desc["addr"],
size=desc["size"],
)
],
nixl_metadata=meta_str,
notification_key=f"img-{shape}",
operation_kind=int(OperationKind.READ),
)

# RDMA read
read_op = await self._connector.begin_read(
rdma_meta, connect.Descriptor(tensor)
)
await read_op.wait_for_completion()
Review comment (Contributor Author, on lines +158 to +197): Not a NIXL expert, so please let me know if I can be doing anything here better.

# Convert to PIL.Image; mode is chosen from the channel count (grayscale, RGB, RGBA)
arr = tensor.numpy()
modes = {1: "L", 3: "RGB", 4: "RGBA"}
mode = modes[shape[2]]
if mode == "L":
arr = arr.squeeze(-1)
return PIL.Image.fromarray(arr, mode)
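The review comments above suggest the passive flow instead of hand-building RdmaMetadata for an active read. For context, this is a minimal sketch of that producer/consumer split: begin_read, wait_for_completion, Connector and Descriptor appear in this diff, while create_readable() and the metadata attribute are assumptions about the dynamo.nixl_connect API taken from the reviewer's wording, not code from this PR.

import torch

import dynamo.nixl_connect as connect


async def expose_decoded_image(connector: connect.Connector, tensor: torch.Tensor):
    # Producer side (sketch): register the decoded tensor as a passive readable
    # region and hand its metadata to the peer, e.g. inside the request payload.
    readable_op = connector.create_readable(connect.Descriptor(tensor))  # assumed API
    return readable_op, readable_op.metadata  # assumed attribute


async def fetch_decoded_image(connector: connect.Connector, rdma_meta, shape):
    # Consumer side: the active read takes the metadata it received from the
    # producer and pulls the bytes into a locally allocated buffer.
    local = torch.empty(shape, dtype=torch.uint8)
    read_op = await connector.begin_read(rdma_meta, connect.Descriptor(local))
    await read_op.wait_for_completion()
    return local

In this scheme the producer keeps the readable operation alive until the consumer's read completes, and the consumer never composes SerializedDescriptor entries by hand.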

async def _extract_multimodal_data(
self, request: Dict[str, Any]
) -> Dict[str, Any] | None:
"""
Extract and decode multimodal data from PreprocessedRequest.

Supports two variants:
1. Url: the frontend passes a URL and the backend fetches and decodes the image
2. Decoded: the frontend has already decoded the image; the backend reads it via NIXL RDMA
"""
if "multi_modal_data" not in request or request["multi_modal_data"] is None:
return None
@@ -165,22 +222,20 @@ async def _extract_multimodal_data(
         # Process image_url entries
         images = []
         for item in mm_map.get(IMAGE_URL_KEY, []):
-            if isinstance(item, dict) and URL_VARIANT_KEY in item:
+            if isinstance(item, dict) and DECODED_VARIANT_KEY in item:
+                decoded_meta = item[DECODED_VARIANT_KEY]
+                image = await self._read_decoded_image_via_nixl(decoded_meta)
+                images.append(image)
+                logger.info(
+                    f"Using DECODED path: Loaded image via NIXL RDMA "
+                    f"(shape={decoded_meta.get('shape')}, dtype={decoded_meta.get('dtype')})"
+                )
+            elif isinstance(item, dict) and URL_VARIANT_KEY in item:
                 url = item[URL_VARIANT_KEY]
-                try:
-                    # ImageLoader supports both data: and http(s): URLs with caching
-                    image = await self.image_loader.load_image(url)
-                    images.append(image)
-                    logger.debug(f"Loaded image from URL: {url[:80]}...")
-                except Exception:
-                    logger.exception(f"Failed to load image from {url[:80]}...")
-                    raise
-            elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
-                # Decoded support from PRs #3971/#3988 (frontend decoding + NIXL transfer)
-                # Will contain NIXL metadata for direct memory access
-                # TODO: Implement NIXL read when PRs merge
-                logger.warning(
-                    "Decoded multimodal data not yet supported in standard worker"
-                )
+                image = await self.image_loader.load_image(url)
+                images.append(image)
+                logger.info(
+                    f"Using URL path: Loaded image from URL (type={url.split(':')[0]})"
+                )

         if images:
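To make the two variants concrete, here is a hypothetical fragment of a request as _extract_multimodal_data would receive it. The nested field names (nixl_metadata, nixl_descriptor, mem_type, device_id, addr, size, shape, dtype) come from the handler code above; the concrete values and the import path for the key constants are assumptions for illustration only.

# Illustrative payload only; values are invented.
from dynamo.vllm.handlers import (  # assumed import path for the constants
    DECODED_VARIANT_KEY,
    IMAGE_URL_KEY,
    URL_VARIANT_KEY,
)

url_item = {URL_VARIANT_KEY: "https://example.com/cat.png"}

decoded_item = {
    DECODED_VARIANT_KEY: {
        "nixl_metadata": "<serialized NIXL agent metadata>",
        "nixl_descriptor": {
            "mem_type": "Dram",  # CPU memory; GPU memory would carry a device_id
            "device_id": 0,
            "addr": 0x7F0000000000,  # remote buffer address (made up)
            "size": 224 * 224 * 3,
        },
        "shape": [224, 224, 3],  # H x W x C, uint8
        "dtype": "uint8",
    }
}

request = {"multi_modal_data": {IMAGE_URL_KEY: [url_item, decoded_item]}}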
22 changes: 22 additions & 0 deletions components/src/dynamo/vllm/main.py
@@ -309,6 +309,26 @@ async def register_vllm_model(
data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
runtime_config.data_parallel_size = data_parallel_size

# Conditionally enable frontend decoding if --frontend-decoding flag is set
media_decoder = None
media_fetcher = None
if config.frontend_decoding:
try:
from dynamo.llm import MediaDecoder, MediaFetcher

media_decoder = MediaDecoder()
media_fetcher = MediaFetcher()
logger.info(
"Frontend decoding enabled: images will be decoded in Rust frontend "
"and transferred via NIXL RDMA"
)
except ImportError as e:
raise RuntimeError(
"Frontend decoding (--frontend-decoding) requires building Dynamo's "
"Rust components with '--features media-nixl'. "
f"Import failed: {e}"
) from e

await register_llm(
model_input,
model_type,
@@ -319,6 +339,8 @@
migration_limit=migration_limit,
runtime_config=runtime_config,
custom_template_path=config.custom_jinja_template,
media_decoder=media_decoder,
media_fetcher=media_fetcher,
)
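The guarded import above doubles as a build check. A minimal preflight probe along the same lines (not part of this PR) can confirm the media-nixl feature is available before launching a worker with --frontend-decoding:

# Not from this PR: availability probe for the media-nixl build of dynamo.llm.
def frontend_decoding_available() -> bool:
    try:
        # Present only when Dynamo's Rust components were built with
        # '--features media-nixl'.
        from dynamo.llm import MediaDecoder, MediaFetcher  # noqa: F401
        return True
    except ImportError:
        return False


if __name__ == "__main__":
    print("frontend decoding available:", frontend_decoding_available())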

