-
Notifications
You must be signed in to change notification settings - Fork 738
feat: Standalone encoder in dynamo trtllm #4668
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
3438e07
9d831c8
86a2b46
6b58e80
60faba0
0fb6e95
fa7136e
2d6a358
8e60929
ae9b10b
fdb86d3
9ccb564
6101937
1f152c0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from enum import Enum | ||
|
|
||
|
|
||
| class DisaggregationMode(Enum): | ||
| AGGREGATED = "prefill_and_decode" | ||
| PREFILL = "prefill" | ||
| DECODE = "decode" | ||
| ENCODE = "encode" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,11 +2,14 @@ | |
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import logging | ||
| from dataclasses import asdict | ||
| from typing import Any, Dict, Union | ||
|
|
||
| import torch | ||
| from tensorrt_llm.inputs import default_multimodal_input_loader | ||
|
|
||
| import dynamo.nixl_connect as nixl_connect | ||
| from dynamo.trtllm.utils.disagg_utils import DisaggregatedParamsCodec | ||
|
|
||
|
|
||
| class EncodeHelper: | ||
|
|
@@ -185,10 +188,14 @@ async def read_embeddings_from_encode_response( | |
| return encodings_tensor | ||
|
|
||
| @staticmethod | ||
| async def process_embedding_request( | ||
| async def process_encode_request( | ||
| request: Dict[str, Any], | ||
| multimodal_processor, | ||
| connector: nixl_connect.Connector, | ||
| tokenizer=None, | ||
| model_dir=None, | ||
| model_type=None, | ||
| engine=None, | ||
| ): | ||
| """ | ||
| Process embedding request by loading embeddings and creating NIXL readable operation. | ||
|
|
@@ -206,57 +213,115 @@ async def process_embedding_request( | |
| messages = request.get("extra_args", {}).get( | ||
| "messages", request.get("messages", []) | ||
| ) | ||
| _, _, embedding_paths = multimodal_processor.extract_prompt_and_media(messages) | ||
|
|
||
| if not embedding_paths: | ||
| # Placeholder for TRTLLM Encoder to be called | ||
| # TRTLLM Encoder will return a memory handler on the encoder GPU with the encodings | ||
| logging.warning( | ||
| "No embedding paths found, NIXL transfer for image urls not supported by TRTLLM Encoder yet" | ||
| ( | ||
| text_prompt, | ||
| image_urls, | ||
| embedding_paths, | ||
| ) = multimodal_processor.extract_prompt_and_media(messages) | ||
|
|
||
| if embedding_paths: | ||
| # Load the embeddings data | ||
| loaded_data = multimodal_processor.load_tensor_from_path_or_url( | ||
| embedding_paths[0] | ||
| ) | ||
| yield {"error": "No embedding paths found"} | ||
| return | ||
|
|
||
| # Load the embeddings data | ||
| loaded_data = multimodal_processor.load_tensor_from_path_or_url( | ||
| embedding_paths[0] | ||
| ) | ||
|
|
||
| # Handle both tensor and dictionary formats | ||
| if isinstance(loaded_data, dict): | ||
| # Dictionary format (e.g., maverick_mm_embed_seashore_v3.pt) | ||
| encodings = loaded_data.get("mm_embeddings") | ||
| if encodings is None: | ||
| yield {"error": "Dictionary embeddings missing 'mm_embeddings' key"} | ||
| # Handle both tensor and dictionary formats | ||
| if isinstance(loaded_data, dict): | ||
| # Dictionary format (e.g., maverick_mm_embed_seashore_v3.pt) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: was this in reference to some local debugging? Should it be removed? |
||
| encodings = loaded_data.get("mm_embeddings") | ||
| if encodings is None: | ||
| yield {"error": "Dictionary embeddings missing 'mm_embeddings' key"} | ||
| return | ||
|
|
||
| # Store auxiliary data for later transmission | ||
| auxiliary_data = { | ||
| k: v for k, v in loaded_data.items() if k != "mm_embeddings" | ||
| } | ||
| else: | ||
| # Tensor format (e.g., llava_next_mm_embed_seashore.pt) | ||
| encodings = loaded_data | ||
| auxiliary_data = {} | ||
| # Create readable operation with main embeddings tensor (works for both formats) | ||
| descriptor = nixl_connect.Descriptor(encodings) | ||
| with connector.create_readable(descriptor) as readable_op: | ||
| # Get the metadata for the readable operation | ||
| op_metadata = readable_op.metadata() | ||
|
|
||
| # Send back shape info, readable metadata, and serialized auxiliary data | ||
| response = { | ||
| "nixl_readable_metadata": op_metadata.model_dump(), | ||
| "embeddings_shape": list(encodings.shape), | ||
| "embeddings_dtype": str(encodings.dtype), | ||
| "auxiliary_data": EncodeHelper.serialize_tensor_dict( | ||
| auxiliary_data | ||
| ), | ||
| } | ||
| yield response | ||
|
|
||
| # Wait for the prefill worker to complete the read operation | ||
| logging.debug( | ||
| "EncodeHelper waiting for PrefillHandler to read embeddings..." | ||
| ) | ||
| await readable_op.wait_for_completion() | ||
| logging.debug("EncodeHelper completed readable operation.") | ||
| elif image_urls and text_prompt: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to add a |
||
| # Use trtllm MultimodalEncoder to generate embeddings | ||
| inputs = default_multimodal_input_loader( | ||
| tokenizer=tokenizer, | ||
| model_dir=model_dir, | ||
| model_type=model_type, | ||
| modality="image", | ||
| prompts=[text_prompt], | ||
| media=image_urls[0], | ||
| ) | ||
| # engine.llm is the MultimodalEncoder instance | ||
| # MultimodalEncoder.generate() returns a list of GenerationResult objects | ||
| encoder_outputs = list(engine.llm.generate(inputs)) | ||
| if not encoder_outputs: | ||
| logging.error("ENCODE WORKER: encoder_outputs is empty") | ||
| yield {"ep_disaggregated_params": None} | ||
| return | ||
|
|
||
| # Store auxiliary data for later transmission | ||
| auxiliary_data = { | ||
| k: v for k, v in loaded_data.items() if k != "mm_embeddings" | ||
| } | ||
| else: | ||
| # Tensor format (e.g., llava_next_mm_embed_seashore.pt) | ||
| encodings = loaded_data | ||
| auxiliary_data = {} | ||
|
|
||
| # Create readable operation with main embeddings tensor (works for both formats) | ||
| descriptor = nixl_connect.Descriptor(encodings) | ||
| with connector.create_readable(descriptor) as readable_op: | ||
| # Get the metadata for the readable operation | ||
| op_metadata = readable_op.metadata() | ||
|
|
||
| # Send back shape info, readable metadata, and serialized auxiliary data | ||
| response = { | ||
| "nixl_readable_metadata": op_metadata.model_dump(), | ||
| "embeddings_shape": list(encodings.shape), | ||
| "embeddings_dtype": str(encodings.dtype), | ||
| "auxiliary_data": EncodeHelper.serialize_tensor_dict(auxiliary_data), | ||
| } | ||
| yield response | ||
|
|
||
| # Wait for the prefill worker to complete the read operation | ||
| ep_disaggregated_params = encoder_outputs[0].disaggregated_params | ||
| if ep_disaggregated_params is None: | ||
| logging.error( | ||
| "ENCODE WORKER: encoder_outputs[0].disaggregated_params is None" | ||
| ) | ||
| yield {"ep_disaggregated_params": None} | ||
| return | ||
| if ( | ||
| hasattr(ep_disaggregated_params, "multimodal_embedding_handles") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe stupid question: aren't we guaranteed that this attribute exists? Couldn't we just check for Or is the idea that we want to support multiple TRTLLM versions somehow? |
||
| and ep_disaggregated_params.multimodal_embedding_handles | ||
| ): | ||
| logging.debug( | ||
| f"ENCODE WORKER: Generated {len(ep_disaggregated_params.multimodal_embedding_handles)} embedding handle(s)" | ||
| ) | ||
| else: | ||
| logging.warning( | ||
| "ENCODE WORKER: ep_disaggregated_params has no multimodal_embedding_handles" | ||
| ) | ||
| # Prepare for Network Transfer | ||
| encoded_params = DisaggregatedParamsCodec.encode(ep_disaggregated_params) | ||
| params_dict = asdict(encoded_params) | ||
| # Also send the processed prompt (which includes <image> tokens) | ||
| # default_multimodal_input_loader returns a list, get the first element | ||
| processed_prompt = None | ||
| prompt_token_ids = None | ||
| if isinstance(inputs, list) and len(inputs) > 0: | ||
| first_input = inputs[0] | ||
| if isinstance(first_input, dict): | ||
| processed_prompt = first_input.get("prompt") | ||
| else: | ||
| processed_prompt = getattr(first_input, "prompt", None) | ||
| # Tokenize the processed prompt for prefill worker | ||
| if processed_prompt and tokenizer is not None: | ||
| prompt_token_ids = tokenizer.encode( | ||
| processed_prompt, add_special_tokens=False | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you leave a comment why |
||
| ) | ||
| logging.debug( | ||
| "EncodeHelper waiting for PrefillHandler to read embeddings..." | ||
| f"ENCODE WORKER: Extracted processed_prompt: {processed_prompt}" | ||
| ) | ||
| await readable_op.wait_for_completion() | ||
| logging.debug("EncodeHelper completed readable operation.") | ||
| yield { | ||
| "ep_disaggregated_params": params_dict, | ||
| "processed_prompt": processed_prompt, # Prompt with <image> tokens | ||
| "prompt_token_ids": prompt_token_ids, # Token IDs for consistency | ||
| } | ||
| return | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,9 +4,11 @@ | |
| import enum | ||
| import logging | ||
| from contextlib import asynccontextmanager | ||
| from typing import AsyncGenerator, Optional | ||
| from typing import AsyncGenerator, Optional, Union | ||
|
|
||
| from tensorrt_llm import LLM | ||
| from tensorrt_llm import LLM, MultimodalEncoder | ||
|
|
||
| from dynamo.trtllm.constants import DisaggregationMode | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -19,8 +21,9 @@ class Backend(str, enum.Enum): | |
|
|
||
|
|
||
| class TensorRTLLMEngine: | ||
| def __init__(self, engine_args): | ||
| def __init__(self, engine_args, disaggregation_mode: DisaggregationMode): | ||
| self._llm: Optional[LLM] = None | ||
| self.disaggregation_mode = disaggregation_mode | ||
| backend = engine_args.pop("backend", Backend.PYTORCH) | ||
| if backend == Backend.PYTORCH: | ||
| self._llm_cls = LLM | ||
|
|
@@ -38,7 +41,19 @@ def __init__(self, engine_args): | |
|
|
||
| async def initialize(self): | ||
| if not self._llm: | ||
| self._llm = self._llm_cls(**self.engine_args) | ||
| if self.disaggregation_mode == DisaggregationMode.ENCODE: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Out of curiosity, how is the engine initialized for a prefill worker? Is that encapsulated via |
||
| # Initialize the multimodal encoder for full EPD | ||
| max_batch_size = self.engine_args.pop("max_batch_size", 1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Out of curiosity, why was this necessary? Maybe leave a comment? 🙏 |
||
| model = self.engine_args.pop("model") | ||
| logging.info( | ||
| f"Initializing multimodal encoder with max_batch_size: {max_batch_size}" | ||
| ) | ||
| self._llm = MultimodalEncoder( | ||
| model=model, | ||
| max_batch_size=max_batch_size, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Out of curiosity, why not forward the rest of the |
||
| ) | ||
| else: | ||
| self._llm = self._llm_cls(**self.engine_args) | ||
|
|
||
| async def cleanup(self): | ||
| if self._llm: | ||
|
|
@@ -50,7 +65,7 @@ async def cleanup(self): | |
| self._llm = None | ||
|
|
||
| @property | ||
| def llm(self): | ||
| def llm(self) -> Union[LLM, MultimodalEncoder]: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: could just be the |
||
| if not self._llm: | ||
| raise RuntimeError("Engine not initialized") | ||
| return self._llm | ||
|
|
@@ -91,8 +106,10 @@ def _warn_about_unsupported_field(field_name: str) -> None: | |
|
|
||
|
|
||
| @asynccontextmanager | ||
| async def get_llm_engine(engine_args) -> AsyncGenerator[TensorRTLLMEngine, None]: | ||
| engine = TensorRTLLMEngine(engine_args) | ||
| async def get_llm_engine( | ||
| engine_args, disaggregation_mode: DisaggregationMode | ||
| ) -> AsyncGenerator[TensorRTLLMEngine, None]: | ||
| engine = TensorRTLLMEngine(engine_args, disaggregation_mode) | ||
| try: | ||
| await engine.initialize() | ||
| yield engine | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: comments for this function need to be updated.