-
Notifications
You must be signed in to change notification settings - Fork 738
feat: Adding nixl read() multimodal support for vLLM backend #4271
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
b9f3484
b6ca505
f145c1c
cb0d388
b0221bb
a2687d2
cfee423
43a3f0c
01f94d6
0aabc41
d33d718
94d027d
518e768
76a4683
3089249
e5c495b
90fcf54
8343753
5b89fdc
1d67335
0e2af92
7f41e85
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,12 +9,16 @@ | |
| from contextlib import asynccontextmanager | ||
| from typing import Any, AsyncGenerator, Dict, Final | ||
|
|
||
| import PIL | ||
| import torch | ||
| from vllm.inputs import TokensPrompt | ||
| from vllm.outputs import RequestOutput | ||
| from vllm.sampling_params import SamplingParams | ||
| from vllm.v1.engine.exceptions import EngineDeadError | ||
|
|
||
| import dynamo.nixl_connect as connect | ||
| from dynamo.llm import ZmqKvEventPublisher | ||
| from dynamo.nixl_connect import OperationKind, RdmaMetadata, SerializedDescriptor | ||
| from dynamo.runtime.logging import configure_dynamo_logging | ||
|
|
||
| from .engine_monitor import VllmEngineMonitor | ||
|
|
@@ -93,6 +97,7 @@ def __init__( | |
| self.kv_publishers: list[ZmqKvEventPublisher] | None = None | ||
| self.engine_monitor = VllmEngineMonitor(runtime, engine) | ||
| self.image_loader = ImageLoader() | ||
| self._connector = None # Lazy-initialized on first Decoded variant | ||
| self.temp_dirs: list[tempfile.TemporaryDirectory] = [] | ||
| self.model_max_len = model_max_len | ||
|
|
||
|
|
@@ -150,11 +155,63 @@ def cleanup(self): | |
| except Exception as e: | ||
| logger.warning(f"Failed to clean up temp directory: {e}") | ||
|
|
||
| async def _read_decoded_image_via_nixl( | ||
| self, decoded_meta: Dict[str, Any] | ||
| ) -> PIL.Image.Image: | ||
| """Read decoded image via NIXL RDMA and convert to PIL.Image.""" | ||
| # Lazy-init connector | ||
| if self._connector is None: | ||
| self._connector = connect.Connector() | ||
| await self._connector.initialize() | ||
| logger.info("NIXL connector initialized for decoded media") | ||
|
|
||
| # Extract fields | ||
| meta_str = decoded_meta["nixl_metadata"] | ||
| desc = decoded_meta["nixl_descriptor"] | ||
| shape = decoded_meta["shape"] | ||
|
|
||
| # Create tensor to receive RDMA data | ||
| tensor = torch.empty(shape, dtype=torch.uint8) | ||
|
|
||
| # Build RdmaMetadata from frontend-provided descriptor | ||
| # Frontend sends compressed metadata (matches Python nixl_connect) | ||
| rdma_meta = RdmaMetadata( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does this work, and have you tested it? The "normal flow" is to create a passive operation ( given that this is using an active operation (
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Usually metadata comes from the secondary connection, which in turn got it from its |
||
| descriptors=[ | ||
| SerializedDescriptor( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why not pass |
||
| device="cpu" | ||
| if desc.get("mem_type") == "Dram" | ||
| else f"cuda:{desc.get('device_id', 0)}", | ||
| ptr=desc["addr"], | ||
| size=desc["size"], | ||
| ) | ||
| ], | ||
| nixl_metadata=meta_str, | ||
| notification_key=f"img-{shape}", | ||
| operation_kind=int(OperationKind.READ), | ||
| ) | ||
|
|
||
| # RDMA read | ||
| read_op = await self._connector.begin_read( | ||
| rdma_meta, connect.Descriptor(tensor) | ||
| ) | ||
| await read_op.wait_for_completion() | ||
|
Comment on lines
+158
to
+197
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not a NIXL expert, so please let me know if there is anything I could be doing better here. |
||
|
|
||
| # Convert to PIL.Image (assume RGB, handle RGBA/grayscale) | ||
| arr = tensor.numpy() | ||
| modes = {1: "L", 3: "RGB", 4: "RGBA"} | ||
| if modes[shape[2]] == "L": | ||
| arr = arr.squeeze(-1) | ||
| return PIL.Image.fromarray(arr, modes[shape[2]]) | ||
|
|
||
| async def _extract_multimodal_data( | ||
| self, request: Dict[str, Any] | ||
| ) -> Dict[str, Any] | None: | ||
| """ | ||
| Extract and decode multimodal data from PreprocessedRequest. | ||
|
|
||
| Supports two variants: | ||
| 1. Url: Frontend passes URL, backend decodes | ||
| 2. Decoded: Frontend decoded, NIXL RDMA transfer | ||
KrishnanPrash marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| if "multi_modal_data" not in request or request["multi_modal_data"] is None: | ||
| return None | ||
|
|
@@ -165,22 +222,20 @@ async def _extract_multimodal_data( | |
| # Process image_url entries | ||
| images = [] | ||
| for item in mm_map.get(IMAGE_URL_KEY, []): | ||
| if isinstance(item, dict) and URL_VARIANT_KEY in item: | ||
| if isinstance(item, dict) and DECODED_VARIANT_KEY in item: | ||
| decoded_meta = item[DECODED_VARIANT_KEY] | ||
| image = await self._read_decoded_image_via_nixl(decoded_meta) | ||
| images.append(image) | ||
| logger.info( | ||
| f"Using DECODED path: Loaded image via NIXL RDMA " | ||
| f"(shape={decoded_meta.get('shape')}, dtype={decoded_meta.get('dtype')})" | ||
| ) | ||
| elif isinstance(item, dict) and URL_VARIANT_KEY in item: | ||
| url = item[URL_VARIANT_KEY] | ||
| try: | ||
| # ImageLoader supports both data: and http(s): URLs with caching | ||
| image = await self.image_loader.load_image(url) | ||
| images.append(image) | ||
| logger.debug(f"Loaded image from URL: {url[:80]}...") | ||
| except Exception: | ||
| logger.exception(f"Failed to load image from {url[:80]}...") | ||
| raise | ||
| elif isinstance(item, dict) and DECODED_VARIANT_KEY in item: | ||
| # Decoded support from PRs #3971/#3988 (frontend decoding + NIXL transfer) | ||
| # Will contain NIXL metadata for direct memory access | ||
| # TODO: Implement NIXL read when PRs merge | ||
| logger.warning( | ||
| "Decoded multimodal data not yet supported in standard worker" | ||
| image = await self.image_loader.load_image(url) | ||
| images.append(image) | ||
| logger.info( | ||
| f"Using URL path: Loaded image from URL (type={url.split(':')[0]})" | ||
| ) | ||
|
|
||
| if images: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What type is `desc`?