Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Simplified workflow
Signed-off-by: Krishnan Prashanth <[email protected]>
  • Loading branch information
KrishnanPrash committed Nov 12, 2025
commit 76a4683fa5524e033e76fbb5eb5cc5a346ecef29
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

166 changes: 37 additions & 129 deletions components/src/dynamo/vllm/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import base64
import logging
import os
import zlib
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Final
Expand All @@ -18,20 +16,12 @@

import dynamo.nixl_connect as connect
from dynamo.llm import ZmqKvEventPublisher
from dynamo.nixl_connect import OperationKind, RdmaMetadata, SerializedDescriptor
from dynamo.runtime.logging import configure_dynamo_logging

from .engine_monitor import VllmEngineMonitor
from .multimodal_utils.image_loader import ImageLoader

# For constructing RdmaMetadata from Decoded variant
try:
from dynamo.nixl_connect import OperationKind, RdmaMetadata, SerializedDescriptor
except ImportError:
# If nixl_connect not available, will fail at runtime when Decoded variant encountered
RdmaMetadata = None
SerializedDescriptor = None
OperationKind = None

# Multimodal data dictionary keys
IMAGE_URL_KEY: Final = "image_url"
VIDEO_URL_KEY: Final = "video_url"
Expand Down Expand Up @@ -134,135 +124,53 @@ def cleanup(self):
"""Override in subclasses if cleanup is needed."""
pass

async def _ensure_connector_initialized(self):
"""
Lazy initialization of NIXL connector.
Only called when Decoded variant is encountered.
"""
if self._connector is None:
logger.info("Initializing NIXL connector for decoded media support")
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("NIXL connector initialized")

async def _read_decoded_image_via_nixl(
self, decoded_meta: Dict[str, Any]
) -> Image.Image:
"""
Read decoded image data via NIXL RDMA.

Args:
decoded_meta: Dictionary containing:
- nixl_metadata: Base64-encoded NIXL agent metadata
- nixl_descriptor: {addr, size, mem_type, device_id}
- shape: [height, width, channels]
- dtype: Data type (e.g., "UINT8")
- metadata: Optional image metadata (format, color_type, etc.)

Returns:
PIL.Image object
"""
# Ensure connector is initialized
await self._ensure_connector_initialized()

# Extract and validate required fields
if (
"nixl_metadata" not in decoded_meta
or "shape" not in decoded_meta
or "nixl_descriptor" not in decoded_meta
):
raise ValueError(
f"Decoded variant missing required fields. Got keys: {decoded_meta.keys()}"
)
"""Read decoded image via NIXL RDMA and convert to PIL.Image."""
# Lazy-init connector
if self._connector is None:
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("NIXL connector initialized for decoded media")

nixl_metadata_str = decoded_meta["nixl_metadata"]
nixl_descriptor = decoded_meta["nixl_descriptor"]
# Extract fields
meta_str = decoded_meta["nixl_metadata"]
desc = decoded_meta["nixl_descriptor"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what type is desc?

shape = decoded_meta["shape"]
dtype_str = decoded_meta.get("dtype", "UINT8")

# Frontend only sends UINT8 for images currently
if dtype_str != "UINT8":
raise ValueError(
f"Unsupported dtype: {dtype_str} (only UINT8 supported for images)"
)

# Create empty tensor to receive RDMA data
# Shape from frontend is [height, width, channels]
tensor = torch.empty(shape, dtype=torch.uint8, device="cpu")
local_descriptor = connect.Descriptor(tensor)

# Construct RdmaMetadata object from decoded_meta
# Frontend sends nixl_descriptor with {addr, size, mem_type, device_id}
# Need to convert to SerializedDescriptor format
mem_type = nixl_descriptor.get("mem_type", "Dram")
device_str = (
"cpu"
if mem_type == "Dram"
else f"cuda:{nixl_descriptor.get('device_id', 0)}"
)

serialized_desc = SerializedDescriptor(
device=device_str, ptr=nixl_descriptor["addr"], size=nixl_descriptor["size"]
)

# Fix nixl_metadata format issue:
# Backend expects: "b64:<zlib_compressed_base64>"
# Frontend sends: "b64:<uncompressed_base64>" (PR #3988 bug)
# Workaround: Compress if not already compressed
if nixl_metadata_str.startswith("b64:"):
# Decode to check if compressed
try:
decoded_bytes = base64.b64decode(nixl_metadata_str[4:])
# Try to decompress - if it works, already compressed
try:
zlib.decompress(decoded_bytes)
# Already compressed, use as-is
final_nixl_metadata = nixl_metadata_str
except zlib.error:
# Not compressed, need to compress
compressed = zlib.compress(decoded_bytes, level=6)
reencoded = base64.b64encode(compressed).decode("utf-8")
final_nixl_metadata = f"b64:{reencoded}"
logger.debug("Compressed uncompressed NIXL metadata from frontend")
except Exception as e:
raise ValueError(f"Failed to decode nixl_metadata: {e}")
else:
final_nixl_metadata = nixl_metadata_str

rdma_metadata = RdmaMetadata(
descriptors=[serialized_desc],
nixl_metadata=final_nixl_metadata,
notification_key=f"decoded-image-{decoded_meta.get('shape', 'unknown')}",
# Create tensor to receive RDMA data
tensor = torch.empty(shape, dtype=torch.uint8)

# Build RdmaMetadata from frontend-provided descriptor
# Frontend sends compressed metadata (matches Python nixl_connect)
rdma_meta = RdmaMetadata(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this work, have you tested it?

the "normal flow" is to create a passive operation (ReadableOp or WritableOp) and use their .metadata property to get the set of SerializedDescriptors and not manually compose this.

given that this is using an active operation (ReadOp) it should be taking in the metadata to perform the read, not sending the metadata.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

usually metadata comes from the secondary connection, which in turn got it from its ReadableOperation.

descriptors=[
SerializedDescriptor(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not pass desc to a Descriptor and then serialize the descriptor to get the metadata?

device="cpu"
if desc.get("mem_type") == "Dram"
else f"cuda:{desc.get('device_id', 0)}",
ptr=desc["addr"],
size=desc["size"],
)
],
nixl_metadata=meta_str,
notification_key=f"img-{shape}",
operation_kind=int(OperationKind.READ),
)

# Read via NIXL RDMA
read_op = await self._connector.begin_read(rdma_metadata, local_descriptor)
# RDMA read
read_op = await self._connector.begin_read(
rdma_meta, connect.Descriptor(tensor)
)
await read_op.wait_for_completion()
logger.debug(f"Loaded image via NIXL RDMA: shape={shape}")

# Convert tensor to PIL.Image
# Tensor shape is [H, W, C], dtype is uint8
# PIL.Image.fromarray expects numpy array
numpy_array = tensor.numpy()

# Determine PIL mode based on number of channels (common cases)
# Frontend sends 3D array [H, W, C]
num_channels = shape[2]
if num_channels == 3:
mode = "RGB" # Most common
elif num_channels == 4:
mode = "RGBA"
elif num_channels == 1:
mode = "L" # Grayscale
numpy_array = numpy_array.squeeze(-1)
else:
raise ValueError(
f"Unsupported channel count: {num_channels} (expected 1, 3, or 4)"
)

pil_image = Image.fromarray(numpy_array, mode=mode)
return pil_image
# Convert to PIL.Image (assume RGB, handle RGBA/grayscale)
arr = tensor.numpy()
modes = {1: "L", 3: "RGB", 4: "RGBA"}
if modes[shape[2]] == "L":
arr = arr.squeeze(-1)
return Image.fromarray(arr, modes[shape[2]])

async def _extract_multimodal_data(
self, request: Dict[str, Any]
Expand Down
24 changes: 5 additions & 19 deletions components/src/dynamo/vllm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,23 +300,9 @@ async def register_vllm_model(
data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
runtime_config.data_parallel_size = data_parallel_size

# Configure media decoder for frontend image decoding (PR #3988)
# This enables frontend to decode images and transfer via NIXL RDMA
media_decoder = MediaDecoder()
media_decoder.image_decoder(
{
"max_image_width": 4096,
"max_image_height": 4096,
"max_alloc": 128 * 1024 * 1024, # 128MB
}
)

media_fetcher = MediaFetcher()
# Security: Only allow standard schemes, no direct IPs
media_fetcher.allow_direct_ip(False)
media_fetcher.allow_direct_port(False)
media_fetcher.timeout_ms(30000) # 30s timeout

# Enable frontend RDMA decoding with default settings
# MediaDecoder defaults: 128MB limit, sensible image size limits
# MediaFetcher defaults: 30s timeout, secure (no direct IP/port)
await register_llm(
model_input,
model_type,
Expand All @@ -327,8 +313,8 @@ async def register_vllm_model(
migration_limit=migration_limit,
runtime_config=runtime_config,
custom_template_path=config.custom_jinja_template,
media_decoder=media_decoder,
media_fetcher=media_fetcher,
media_decoder=MediaDecoder(),
media_fetcher=MediaFetcher(),
)


Expand Down
1 change: 1 addition & 0 deletions lib/bindings/python/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions lib/llm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ base64 = { version = "0.22" }
image = { version = "0.25", features = ["serde"] }
tokio-rayon = {version = "2" }
ndarray = { version = "0.16" }
flate2 = { version = "1.0" }

# Publishers
zeromq = "0.4.1"
Expand Down
12 changes: 11 additions & 1 deletion lib/llm/src/preprocessor/media/rdma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,23 @@ impl<D: Dimension> TryFrom<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData {
// TODO: pre-allocate a fixed NIXL-registered RAM pool so metadata can be cached on the target?
#[cfg(feature = "media-nixl")]
pub fn get_nixl_metadata(agent: &NixlAgent, _storage: &SystemStorage) -> Result<String> {
use flate2::Compression;
use flate2::write::ZlibEncoder;
use std::io::Write;

// WAR: Until https://github.com/ai-dynamo/nixl/pull/970 is merged, can't use get_local_partial_md
let nixl_md = agent.raw_agent().get_local_md()?;
// let mut reg_desc_list = RegDescList::new(MemType::Dram)?;
// reg_desc_list.add_storage_desc(storage)?;
// let nixl_partial_md = agent.raw_agent().get_local_partial_md(&reg_desc_list, None)?;

let b64_encoded = general_purpose::STANDARD.encode(&nixl_md);
// Compress metadata before base64 encoding (matches Python nixl_connect behavior)
// Backend expects: b64:<base64_of_compressed_bytes>
let mut encoder = ZlibEncoder::new(Vec::new(), Compression::new(6));
encoder.write_all(&nixl_md)?;
let compressed = encoder.finish()?;
Copy link
Contributor Author

@KrishnanPrash KrishnanPrash Nov 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once again, welcome any suggestions on correct nixl usage.


let b64_encoded = general_purpose::STANDARD.encode(&compressed);
Ok(format!("b64:{}", b64_encoded))
}

Expand Down
Loading