11 changes: 11 additions & 0 deletions components/src/dynamo/trtllm/constants.py
@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from enum import Enum


class DisaggregationMode(Enum):
AGGREGATED = "prefill_and_decode"
PREFILL = "prefill"
DECODE = "decode"
ENCODE = "encode"
167 changes: 116 additions & 51 deletions components/src/dynamo/trtllm/encode_helper.py
@@ -2,11 +2,14 @@
# SPDX-License-Identifier: Apache-2.0

import logging
from dataclasses import asdict
from typing import Any, Dict, Union

import torch
from tensorrt_llm.inputs import default_multimodal_input_loader

import dynamo.nixl_connect as nixl_connect
from dynamo.trtllm.utils.disagg_utils import DisaggregatedParamsCodec


class EncodeHelper:
@@ -185,10 +188,14 @@ async def read_embeddings_from_encode_response(
return encodings_tensor

@staticmethod
async def process_embedding_request(
async def process_encode_request(
request: Dict[str, Any],
multimodal_processor,
connector: nixl_connect.Connector,
tokenizer=None,
model_dir=None,
model_type=None,
engine=None,
):
"""
Process an encode request: either load precomputed embeddings and expose them through a NIXL readable operation, or run the TRT-LLM MultimodalEncoder on image URLs and yield the encoded disaggregated params.
[Review comment] nit: comments for this function need to be updated.

@@ -206,57 +213,115 @@ async def process_embedding_request(
messages = request.get("extra_args", {}).get(
"messages", request.get("messages", [])
)
_, _, embedding_paths = multimodal_processor.extract_prompt_and_media(messages)

if not embedding_paths:
# Placeholder for TRTLLM Encoder to be called
# TRTLLM Encoder will return a memory handler on the encoder GPU with the encodings
logging.warning(
"No embedding paths found, NIXL transfer for image urls not supported by TRTLLM Encoder yet"
(
text_prompt,
image_urls,
embedding_paths,
) = multimodal_processor.extract_prompt_and_media(messages)

if embedding_paths:
# Load the embeddings data
loaded_data = multimodal_processor.load_tensor_from_path_or_url(
embedding_paths[0]
)
yield {"error": "No embedding paths found"}
return

# Load the embeddings data
loaded_data = multimodal_processor.load_tensor_from_path_or_url(
embedding_paths[0]
)

# Handle both tensor and dictionary formats
if isinstance(loaded_data, dict):
# Dictionary format (e.g., maverick_mm_embed_seashore_v3.pt)
encodings = loaded_data.get("mm_embeddings")
if encodings is None:
yield {"error": "Dictionary embeddings missing 'mm_embeddings' key"}
# Handle both tensor and dictionary formats
if isinstance(loaded_data, dict):
# Dictionary format (e.g., maverick_mm_embed_seashore_v3.pt)
[Review comment] Nit: was this in reference to some local debugging? Should it be removed?

encodings = loaded_data.get("mm_embeddings")
if encodings is None:
yield {"error": "Dictionary embeddings missing 'mm_embeddings' key"}
return

# Store auxiliary data for later transmission
auxiliary_data = {
k: v for k, v in loaded_data.items() if k != "mm_embeddings"
}
else:
# Tensor format (e.g., llava_next_mm_embed_seashore.pt)
encodings = loaded_data
auxiliary_data = {}
# Create readable operation with main embeddings tensor (works for both formats)
descriptor = nixl_connect.Descriptor(encodings)
with connector.create_readable(descriptor) as readable_op:
# Get the metadata for the readable operation
op_metadata = readable_op.metadata()

# Send back shape info, readable metadata, and serialized auxiliary data
response = {
"nixl_readable_metadata": op_metadata.model_dump(),
"embeddings_shape": list(encodings.shape),
"embeddings_dtype": str(encodings.dtype),
"auxiliary_data": EncodeHelper.serialize_tensor_dict(
auxiliary_data
),
}
yield response

# Wait for the prefill worker to complete the read operation
logging.debug(
"EncodeHelper waiting for PrefillHandler to read embeddings..."
)
await readable_op.wait_for_completion()
logging.debug("EncodeHelper completed readable operation.")
elif image_urls and text_prompt:
[Review comment] Do we need to add an else case to handle the situation where image_urls, text_prompt, or both are empty?

# Use trtllm MultimodalEncoder to generate embeddings
inputs = default_multimodal_input_loader(
tokenizer=tokenizer,
model_dir=model_dir,
model_type=model_type,
modality="image",
prompts=[text_prompt],
media=image_urls[0],
)
# engine.llm is the MultimodalEncoder instance
# MultimodalEncoder.generate() returns a list of GenerationResult objects
encoder_outputs = list(engine.llm.generate(inputs))
if not encoder_outputs:
logging.error("ENCODE WORKER: encoder_outputs is empty")
yield {"ep_disaggregated_params": None}
return

# Store auxiliary data for later transmission
auxiliary_data = {
k: v for k, v in loaded_data.items() if k != "mm_embeddings"
}
else:
# Tensor format (e.g., llava_next_mm_embed_seashore.pt)
encodings = loaded_data
auxiliary_data = {}

# Create readable operation with main embeddings tensor (works for both formats)
descriptor = nixl_connect.Descriptor(encodings)
with connector.create_readable(descriptor) as readable_op:
# Get the metadata for the readable operation
op_metadata = readable_op.metadata()

# Send back shape info, readable metadata, and serialized auxiliary data
response = {
"nixl_readable_metadata": op_metadata.model_dump(),
"embeddings_shape": list(encodings.shape),
"embeddings_dtype": str(encodings.dtype),
"auxiliary_data": EncodeHelper.serialize_tensor_dict(auxiliary_data),
}
yield response

# Wait for the prefill worker to complete the read operation
ep_disaggregated_params = encoder_outputs[0].disaggregated_params
if ep_disaggregated_params is None:
logging.error(
"ENCODE WORKER: encoder_outputs[0].disaggregated_params is None"
)
yield {"ep_disaggregated_params": None}
return
if (
hasattr(ep_disaggregated_params, "multimodal_embedding_handles")
[Review comment] Maybe a stupid question: aren't we guaranteed that this attribute exists? Couldn't we just check if ep_disaggregated_params.multimodal_embedding_handles is not None directly? Or is the idea that we want to support multiple TRT-LLM versions somehow?

and ep_disaggregated_params.multimodal_embedding_handles
):
logging.debug(
f"ENCODE WORKER: Generated {len(ep_disaggregated_params.multimodal_embedding_handles)} embedding handle(s)"
)
else:
logging.warning(
"ENCODE WORKER: ep_disaggregated_params has no multimodal_embedding_handles"
)
# Prepare for Network Transfer
encoded_params = DisaggregatedParamsCodec.encode(ep_disaggregated_params)
params_dict = asdict(encoded_params)
# Also send the processed prompt (which includes <image> tokens)
# default_multimodal_input_loader returns a list, get the first element
processed_prompt = None
prompt_token_ids = None
if isinstance(inputs, list) and len(inputs) > 0:
first_input = inputs[0]
if isinstance(first_input, dict):
processed_prompt = first_input.get("prompt")
else:
processed_prompt = getattr(first_input, "prompt", None)
# Tokenize the processed prompt for prefill worker
if processed_prompt and tokenizer is not None:
prompt_token_ids = tokenizer.encode(
processed_prompt, add_special_tokens=False
[Review comment] Could you leave a comment explaining why add_special_tokens is set to False?

)
logging.debug(
"EncodeHelper waiting for PrefillHandler to read embeddings..."
f"ENCODE WORKER: Extracted processed_prompt: {processed_prompt}"
)
await readable_op.wait_for_completion()
logging.debug("EncodeHelper completed readable operation.")
yield {
"ep_disaggregated_params": params_dict,
"processed_prompt": processed_prompt, # Prompt with <image> tokens
"prompt_token_ids": prompt_token_ids, # Token IDs for consistency
}
return
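
For orientation, a minimal consumer-side sketch of the two response shapes this generator now yields; the handler name and the elided read/generate steps are hypothetical, and only the dictionary keys come from the code above:

async def handle_encode_response(response: dict) -> None:
    # Hypothetical prefill-side handler; it only illustrates the branching.
    if "error" in response:
        raise RuntimeError(response["error"])
    if "nixl_readable_metadata" in response:
        # Embedding-path flow: the encode worker exposed a NIXL-readable tensor
        # and blocks in wait_for_completion() until it has been read.
        shape = response["embeddings_shape"]
        dtype = response["embeddings_dtype"]
        aux = response["auxiliary_data"]
        ...  # read the tensor over NIXL (see read_embeddings_from_encode_response)
    elif response.get("ep_disaggregated_params") is not None:
        # Image-URL flow: the disaggregated params carry the multimodal embedding
        # handles, plus the processed prompt (with <image> tokens) and token ids.
        params = response["ep_disaggregated_params"]
        prompt = response.get("processed_prompt")
        token_ids = response.get("prompt_token_ids")
        ...
    else:
        ...  # encoder produced nothing usable
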
31 changes: 24 additions & 7 deletions components/src/dynamo/trtllm/engine.py
@@ -4,9 +4,11 @@
import enum
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional
from typing import AsyncGenerator, Optional, Union

from tensorrt_llm import LLM
from tensorrt_llm import LLM, MultimodalEncoder

from dynamo.trtllm.constants import DisaggregationMode

logger = logging.getLogger(__name__)

@@ -19,8 +21,9 @@ class Backend(str, enum.Enum):


class TensorRTLLMEngine:
def __init__(self, engine_args):
def __init__(self, engine_args, disaggregation_mode: DisaggregationMode):
self._llm: Optional[LLM] = None
self.disaggregation_mode = disaggregation_mode
backend = engine_args.pop("backend", Backend.PYTORCH)
if backend == Backend.PYTORCH:
self._llm_cls = LLM
@@ -38,7 +41,19 @@ def __init__(self, engine_args):

async def initialize(self):
if not self._llm:
self._llm = self._llm_cls(**self.engine_args)
if self.disaggregation_mode == DisaggregationMode.ENCODE:
[Review comment] Out of curiosity, how is the engine initialized for a prefill worker? Is that encapsulated via engine_args itself somehow? (It might be worth leaving a comment either way 🙏)

# Initialize the multimodal encoder for full EPD
max_batch_size = self.engine_args.pop("max_batch_size", 1)
[Review comment] Out of curiosity, why was this necessary? Maybe leave a comment? 🙏

model = self.engine_args.pop("model")
logging.info(
f"Initializing multimodal encoder with max_batch_size: {max_batch_size}"
)
self._llm = MultimodalEncoder(
model=model,
max_batch_size=max_batch_size,
[Review comment] Out of curiosity, why not forward the rest of self.engine_args? MultimodalEncoder is also an LLM class: https://github.com/NVIDIA/TensorRT-LLM/blob/v1.2.0rc4/tensorrt_llm/llmapi/mm_encoder.py#L16

)
else:
self._llm = self._llm_cls(**self.engine_args)

async def cleanup(self):
if self._llm:
@@ -50,7 +65,7 @@ async def cleanup(self):
self._llm = None

@property
def llm(self):
def llm(self) -> Union[LLM, MultimodalEncoder]:
if not self._llm:
raise RuntimeError("Engine not initialized")
return self._llm
@@ -91,8 +106,10 @@ def _warn_about_unsupported_field(field_name: str) -> None:


@asynccontextmanager
async def get_llm_engine(engine_args) -> AsyncGenerator[TensorRTLLMEngine, None]:
engine = TensorRTLLMEngine(engine_args)
async def get_llm_engine(
engine_args, disaggregation_mode: DisaggregationMode
) -> AsyncGenerator[TensorRTLLMEngine, None]:
engine = TensorRTLLMEngine(engine_args, disaggregation_mode)
try:
await engine.initialize()
yield engine
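
For context, a minimal sketch (hypothetical model path and arguments) of constructing an encode worker with the new signature; in ENCODE mode the engine wraps a MultimodalEncoder rather than an LLM:

from dynamo.trtllm.constants import DisaggregationMode
from dynamo.trtllm.engine import get_llm_engine

async def build_encode_worker_engine():
    # Hypothetical arguments; in ENCODE mode only "model" and "max_batch_size"
    # are consumed by TensorRTLLMEngine.initialize().
    engine_args = {"model": "/path/to/multimodal/model", "max_batch_size": 1}
    async with get_llm_engine(engine_args, DisaggregationMode.ENCODE) as engine:
        encoder = engine.llm  # MultimodalEncoder instance
        ...
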
2 changes: 1 addition & 1 deletion components/src/dynamo/trtllm/main.py
@@ -287,7 +287,7 @@ async def init(runtime: DistributedRuntime, config: Config):
config.dump_config_to, {"engine_args": engine_args, "dynamo_args": config}
)

async with get_llm_engine(engine_args) as engine:
async with get_llm_engine(engine_args, config.disaggregation_mode) as engine:
endpoint = component.endpoint(config.endpoint)

# should ideally call get_engine_runtime_config
36 changes: 29 additions & 7 deletions components/src/dynamo/trtllm/multimodal_processor.py
@@ -154,22 +154,34 @@ def extract_prompt_and_media(
return " ".join(text_parts), image_urls, embedding_paths

async def process_openai_request(
self, request: Dict, embeddings: Any
self, request: Dict, embeddings: Any, ep_disaggregated_params: Any
) -> Optional[Any]:
"""Process OpenAI request and return with multimodal data."""
# Extract messages - check extra_args first (from Rust preprocessor for multimodal)
# Fall back to direct messages field for backward compatibility
self.previous_decoded_text = ""
messages = request.get("extra_args", {}).get(
"messages", request.get("messages", [])
)
text_prompt, image_urls, embedding_paths = self.extract_prompt_and_media(
messages
)

if not image_urls and not embedding_paths:
if not image_urls and not embedding_paths and not ep_disaggregated_params:
logging.warning("No multimodal content, returning None")
return None

processed_prompt_from_encoder = request.get("_epd_processed_prompt")

# Only use EPD flow if we actually have encoder data
# For PD flow (no encoder), fall through to embedding_paths handling
if processed_prompt_from_encoder is not None:
text_prompt = processed_prompt_from_encoder
result = {"prompt": text_prompt}
if "_epd_prompt_token_ids" in request and request["_epd_prompt_token_ids"]:
result["prompt_token_ids"] = request["_epd_prompt_token_ids"]
else:
logging.warning("MM PROCESSOR: No prompt_token_ids from encoder")
return result
loader_kwargs = {}
if embeddings is not None:
# EPD flow
@@ -214,10 +226,20 @@ def create_response_chunk(
if self.tokenizer is None:
raise ValueError("Tokenizer must be provided for creating response chunks.")

new_tokens = output.token_ids[num_output_tokens_so_far:]
# Decode the new token IDs into a string. This is the incremental piece
# of text to be sent to the client.
delta_text = self.tokenizer.decode(new_tokens)
all_tokens = output.token_ids
current_text = self.tokenizer.decode(
all_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
if num_output_tokens_so_far == 0:
# First chunk: use all decoded text
delta_text = current_text
# Store for next iteration
self.previous_decoded_text = current_text
else:
# Incremental chunk: extract delta using cached previous text
delta_text = current_text[len(self.previous_decoded_text) :]
# Update cache for next iteration
self.previous_decoded_text = current_text
# Assemble the delta payload for the response chunk.
delta = {"content": delta_text if delta_text else ""}
if num_output_tokens_so_far == 0:
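
To illustrate why create_response_chunk now decodes the full sequence and slices by the previously decoded text rather than decoding only the new token ids: subword and byte-fallback tokens often do not render correctly in isolation. A minimal sketch, assuming a Hugging Face-style tokenizer and the output/num_output_tokens_so_far names from the surrounding code:

prev_ids = output.token_ids[:num_output_tokens_so_far]
new_ids = output.token_ids[num_output_tokens_so_far:]

# Decoding only the new ids can yield broken text, e.g. a byte-fallback token or
# a subword fragment that needs its left neighbour to render correctly.
naive_delta = tokenizer.decode(new_ids)

# Decoding the full sequence and slicing off the previously decoded prefix keeps
# the delta consistent with the text the client has already received.
full_text = tokenizer.decode(output.token_ids, skip_special_tokens=True)
prev_text = tokenizer.decode(prev_ids, skip_special_tokens=True)
delta_text = full_text[len(prev_text):]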