-
Notifications
You must be signed in to change notification settings - Fork 738
feat: Adding nixl read() multimodal support for vLLM backend #4271
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
b9f3484
b6ca505
f145c1c
cb0d388
b0221bb
a2687d2
cfee423
43a3f0c
01f94d6
0aabc41
d33d718
94d027d
518e768
76a4683
3089249
e5c495b
90fcf54
8343753
5b89fdc
1d67335
0e2af92
7f41e85
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
Signed-off-by: Krishnan Prashanth <[email protected]>
- Loading branch information
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,10 +2,8 @@ | |
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import asyncio | ||
| import base64 | ||
| import logging | ||
| import os | ||
| import zlib | ||
| from abc import ABC, abstractmethod | ||
| from contextlib import asynccontextmanager | ||
| from typing import Any, AsyncGenerator, Dict, Final | ||
|
|
@@ -18,20 +16,12 @@ | |
|
|
||
| import dynamo.nixl_connect as connect | ||
| from dynamo.llm import ZmqKvEventPublisher | ||
| from dynamo.nixl_connect import OperationKind, RdmaMetadata, SerializedDescriptor | ||
| from dynamo.runtime.logging import configure_dynamo_logging | ||
|
|
||
| from .engine_monitor import VllmEngineMonitor | ||
| from .multimodal_utils.image_loader import ImageLoader | ||
|
|
||
| # For constructing RdmaMetadata from Decoded variant | ||
| try: | ||
| from dynamo.nixl_connect import OperationKind, RdmaMetadata, SerializedDescriptor | ||
| except ImportError: | ||
| # If nixl_connect not available, will fail at runtime when Decoded variant encountered | ||
| RdmaMetadata = None | ||
| SerializedDescriptor = None | ||
| OperationKind = None | ||
|
|
||
| # Multimodal data dictionary keys | ||
| IMAGE_URL_KEY: Final = "image_url" | ||
| VIDEO_URL_KEY: Final = "video_url" | ||
|
|
@@ -134,135 +124,53 @@ def cleanup(self): | |
| """Override in subclasses if cleanup is needed.""" | ||
| pass | ||
|
|
||
| async def _ensure_connector_initialized(self): | ||
| """ | ||
| Lazy initialization of NIXL connector. | ||
| Only called when Decoded variant is encountered. | ||
| """ | ||
| if self._connector is None: | ||
| logger.info("Initializing NIXL connector for decoded media support") | ||
| self._connector = connect.Connector() | ||
| await self._connector.initialize() | ||
| logger.info("NIXL connector initialized") | ||
|
|
||
| async def _read_decoded_image_via_nixl( | ||
| self, decoded_meta: Dict[str, Any] | ||
| ) -> Image.Image: | ||
| """ | ||
| Read decoded image data via NIXL RDMA. | ||
|
|
||
| Args: | ||
| decoded_meta: Dictionary containing: | ||
| - nixl_metadata: Base64-encoded NIXL agent metadata | ||
| - nixl_descriptor: {addr, size, mem_type, device_id} | ||
| - shape: [height, width, channels] | ||
| - dtype: Data type (e.g., "UINT8") | ||
| - metadata: Optional image metadata (format, color_type, etc.) | ||
|
|
||
| Returns: | ||
| PIL.Image object | ||
| """ | ||
| # Ensure connector is initialized | ||
| await self._ensure_connector_initialized() | ||
|
|
||
| # Extract and validate required fields | ||
| if ( | ||
| "nixl_metadata" not in decoded_meta | ||
| or "shape" not in decoded_meta | ||
| or "nixl_descriptor" not in decoded_meta | ||
| ): | ||
| raise ValueError( | ||
| f"Decoded variant missing required fields. Got keys: {decoded_meta.keys()}" | ||
| ) | ||
| """Read decoded image via NIXL RDMA and convert to PIL.Image.""" | ||
| # Lazy-init connector | ||
| if self._connector is None: | ||
| self._connector = connect.Connector() | ||
| await self._connector.initialize() | ||
| logger.info("NIXL connector initialized for decoded media") | ||
|
|
||
| nixl_metadata_str = decoded_meta["nixl_metadata"] | ||
| nixl_descriptor = decoded_meta["nixl_descriptor"] | ||
| # Extract fields | ||
| meta_str = decoded_meta["nixl_metadata"] | ||
| desc = decoded_meta["nixl_descriptor"] | ||
| shape = decoded_meta["shape"] | ||
| dtype_str = decoded_meta.get("dtype", "UINT8") | ||
|
|
||
| # Frontend only sends UINT8 for images currently | ||
| if dtype_str != "UINT8": | ||
| raise ValueError( | ||
| f"Unsupported dtype: {dtype_str} (only UINT8 supported for images)" | ||
| ) | ||
|
|
||
| # Create empty tensor to receive RDMA data | ||
| # Shape from frontend is [height, width, channels] | ||
| tensor = torch.empty(shape, dtype=torch.uint8, device="cpu") | ||
| local_descriptor = connect.Descriptor(tensor) | ||
|
|
||
| # Construct RdmaMetadata object from decoded_meta | ||
| # Frontend sends nixl_descriptor with {addr, size, mem_type, device_id} | ||
| # Need to convert to SerializedDescriptor format | ||
| mem_type = nixl_descriptor.get("mem_type", "Dram") | ||
| device_str = ( | ||
| "cpu" | ||
| if mem_type == "Dram" | ||
| else f"cuda:{nixl_descriptor.get('device_id', 0)}" | ||
| ) | ||
|
|
||
| serialized_desc = SerializedDescriptor( | ||
| device=device_str, ptr=nixl_descriptor["addr"], size=nixl_descriptor["size"] | ||
| ) | ||
|
|
||
| # Fix nixl_metadata format issue: | ||
| # Backend expects: "b64:<zlib_compressed_base64>" | ||
| # Frontend sends: "b64:<uncompressed_base64>" (PR #3988 bug) | ||
| # Workaround: Compress if not already compressed | ||
| if nixl_metadata_str.startswith("b64:"): | ||
| # Decode to check if compressed | ||
| try: | ||
| decoded_bytes = base64.b64decode(nixl_metadata_str[4:]) | ||
| # Try to decompress - if it works, already compressed | ||
| try: | ||
| zlib.decompress(decoded_bytes) | ||
| # Already compressed, use as-is | ||
| final_nixl_metadata = nixl_metadata_str | ||
| except zlib.error: | ||
| # Not compressed, need to compress | ||
| compressed = zlib.compress(decoded_bytes, level=6) | ||
| reencoded = base64.b64encode(compressed).decode("utf-8") | ||
| final_nixl_metadata = f"b64:{reencoded}" | ||
| logger.debug("Compressed uncompressed NIXL metadata from frontend") | ||
| except Exception as e: | ||
| raise ValueError(f"Failed to decode nixl_metadata: {e}") | ||
| else: | ||
| final_nixl_metadata = nixl_metadata_str | ||
|
|
||
| rdma_metadata = RdmaMetadata( | ||
| descriptors=[serialized_desc], | ||
| nixl_metadata=final_nixl_metadata, | ||
| notification_key=f"decoded-image-{decoded_meta.get('shape', 'unknown')}", | ||
| # Create tensor to receive RDMA data | ||
| tensor = torch.empty(shape, dtype=torch.uint8) | ||
|
|
||
| # Build RdmaMetadata from frontend-provided descriptor | ||
| # Frontend sends compressed metadata (matches Python nixl_connect) | ||
| rdma_meta = RdmaMetadata( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does this work, have you tested it? The "normal flow" is to create a passive operation ( given that this is using an active operation (
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Usually metadata comes from the secondary connection, which in turn got it from its |
||
| descriptors=[ | ||
| SerializedDescriptor( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why not pass |
||
| device="cpu" | ||
| if desc.get("mem_type") == "Dram" | ||
| else f"cuda:{desc.get('device_id', 0)}", | ||
| ptr=desc["addr"], | ||
| size=desc["size"], | ||
| ) | ||
| ], | ||
| nixl_metadata=meta_str, | ||
| notification_key=f"img-{shape}", | ||
| operation_kind=int(OperationKind.READ), | ||
| ) | ||
|
|
||
| # Read via NIXL RDMA | ||
| read_op = await self._connector.begin_read(rdma_metadata, local_descriptor) | ||
| # RDMA read | ||
| read_op = await self._connector.begin_read( | ||
| rdma_meta, connect.Descriptor(tensor) | ||
| ) | ||
| await read_op.wait_for_completion() | ||
| logger.debug(f"Loaded image via NIXL RDMA: shape={shape}") | ||
|
|
||
| # Convert tensor to PIL.Image | ||
| # Tensor shape is [H, W, C], dtype is uint8 | ||
| # PIL.Image.fromarray expects numpy array | ||
| numpy_array = tensor.numpy() | ||
|
|
||
| # Determine PIL mode based on number of channels (common cases) | ||
| # Frontend sends 3D array [H, W, C] | ||
| num_channels = shape[2] | ||
| if num_channels == 3: | ||
| mode = "RGB" # Most common | ||
| elif num_channels == 4: | ||
| mode = "RGBA" | ||
| elif num_channels == 1: | ||
| mode = "L" # Grayscale | ||
| numpy_array = numpy_array.squeeze(-1) | ||
| else: | ||
| raise ValueError( | ||
| f"Unsupported channel count: {num_channels} (expected 1, 3, or 4)" | ||
| ) | ||
|
|
||
| pil_image = Image.fromarray(numpy_array, mode=mode) | ||
| return pil_image | ||
| # Convert to PIL.Image (assume RGB, handle RGBA/grayscale) | ||
| arr = tensor.numpy() | ||
| modes = {1: "L", 3: "RGB", 4: "RGBA"} | ||
| if modes[shape[2]] == "L": | ||
| arr = arr.squeeze(-1) | ||
| return Image.fromarray(arr, modes[shape[2]]) | ||
|
|
||
| async def _extract_multimodal_data( | ||
| self, request: Dict[str, Any] | ||
|
|
||
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -112,13 +112,23 @@ impl<D: Dimension> TryFrom<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData { | |
| // TODO: pre-allocate a fixed NIXL-registered RAM pool so metadata can be cached on the target? | ||
| #[cfg(feature = "media-nixl")] | ||
| pub fn get_nixl_metadata(agent: &NixlAgent, _storage: &SystemStorage) -> Result<String> { | ||
| use flate2::Compression; | ||
| use flate2::write::ZlibEncoder; | ||
| use std::io::Write; | ||
|
|
||
| // WAR: Until https://github.com/ai-dynamo/nixl/pull/970 is merged, can't use get_local_partial_md | ||
| let nixl_md = agent.raw_agent().get_local_md()?; | ||
| // let mut reg_desc_list = RegDescList::new(MemType::Dram)?; | ||
| // reg_desc_list.add_storage_desc(storage)?; | ||
| // let nixl_partial_md = agent.raw_agent().get_local_partial_md(®_desc_list, None)?; | ||
|
|
||
| let b64_encoded = general_purpose::STANDARD.encode(&nixl_md); | ||
| // Compress metadata before base64 encoding (matches Python nixl_connect behavior) | ||
| // Backend expects: b64:<base64_of_compressed_bytes> | ||
| let mut encoder = ZlibEncoder::new(Vec::new(), Compression::new(6)); | ||
KrishnanPrash marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
KrishnanPrash marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| encoder.write_all(&nixl_md)?; | ||
| let compressed = encoder.finish()?; | ||
|
||
|
|
||
| let b64_encoded = general_purpose::STANDARD.encode(&compressed); | ||
| Ok(format!("b64:{}", b64_encoded)) | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What type is `desc`?