31 changes: 31 additions & 0 deletions components/backends/trtllm/engine_configs/multimodal/encode.yaml
@@ -0,0 +1,31 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 8192
trust_remote_code: true
backend: pytorch
disable_overlap_scheduler: false

cuda_graph_config:
  max_batch_size: 16

kv_cache_config:
  free_gpu_memory_fraction: 0.85
  enable_block_reuse: false

cache_transceiver_config:
  backend: DEFAULT
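For reference, a minimal sketch (not part of this PR) of how these values could reach the engine: the YAML is read into a dict of engine kwargs, and, per the `engine.py` change below, the encode worker pops a top-level `max_batch_size` (defaulting to 1). This file does not set that top-level key; `cuda_graph_config.max_batch_size` is a separate nested setting.

```python
# Illustrative only: load the encode config and show which key the
# MultimodalEncoder path would pick up (see TensorRTLLMEngine.initialize below).
import yaml

# Path assumed relative to the trtllm backend directory.
with open("engine_configs/multimodal/encode.yaml") as f:
    engine_args = yaml.safe_load(f)

# engine.py pops a *top-level* "max_batch_size"; this file only sets the nested
# cuda_graph_config.max_batch_size, so the encoder would fall back to 1 here.
encoder_batch_size = engine_args.pop("max_batch_size", 1)
print(encoder_batch_size)                                   # -> 1
print(engine_args["cuda_graph_config"]["max_batch_size"])   # -> 16
```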
10 changes: 5 additions & 5 deletions components/backends/trtllm/launch/epd_disagg.sh
@@ -3,12 +3,12 @@
# SPDX-License-Identifier: Apache-2.0

# Environment variables with defaults
export MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen2-VL-7B-Instruct"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"Qwen/Qwen2-VL-7B-Instruct"}
export MODEL_PATH=${MODEL_PATH:-"llava-hf/llava-v1.6-mistral-7b-hf"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"llava-v1.6-mistral-7b-hf"}
export DISAGGREGATION_STRATEGY=${DISAGGREGATION_STRATEGY:-"decode_first"}
export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"engine_configs/prefill.yaml"}
export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"engine_configs/decode.yaml"}
export ENCODE_ENGINE_ARGS=${ENCODE_ENGINE_ARGS:-"engine_configs/encode.yaml"}
export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"engine_configs/multimodal/prefill.yaml"}
export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"engine_configs/multimodal/decode.yaml"}
export ENCODE_ENGINE_ARGS=${ENCODE_ENGINE_ARGS:-"engine_configs/multimodal/encode.yaml"}
export PREFILL_CUDA_VISIBLE_DEVICES=${PREFILL_CUDA_VISIBLE_DEVICES:-"0"}
export DECODE_CUDA_VISIBLE_DEVICES=${DECODE_CUDA_VISIBLE_DEVICES:-"1"}
export ENCODE_CUDA_VISIBLE_DEVICES=${ENCODE_CUDA_VISIBLE_DEVICES:-"2"}
16 changes: 16 additions & 0 deletions components/src/dynamo/trtllm/constants.py
@@ -0,0 +1,16 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from enum import Enum


class DisaggregationMode(Enum):
AGGREGATED = "prefill_and_decode"
PREFILL = "prefill"
DECODE = "decode"
ENCODE = "encode"


class DisaggregationStrategy(Enum):
PREFILL_FIRST = "prefill_first"
DECODE_FIRST = "decode_first"
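A short usage sketch of the new enums, assuming a worker maps launch-script environment variables onto them (the exact wiring in the worker entrypoint may differ, and the `DISAGGREGATION_MODE` variable name is an assumption for illustration):

```python
# Illustrative only: parse environment variables into the new enums.
import os

from dynamo.trtllm.constants import DisaggregationMode, DisaggregationStrategy

strategy = DisaggregationStrategy(
    os.environ.get("DISAGGREGATION_STRATEGY", "decode_first")
)
# DISAGGREGATION_MODE is a hypothetical variable name used here for the example.
mode = DisaggregationMode(os.environ.get("DISAGGREGATION_MODE", "encode"))

if mode == DisaggregationMode.ENCODE:
    # The encode worker initializes a MultimodalEncoder instead of an LLM
    # (see the engine.py change below).
    print(f"strategy={strategy.value}, mode={mode.value}")
```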
170 changes: 118 additions & 52 deletions components/src/dynamo/trtllm/encode_helper.py
@@ -2,11 +2,14 @@
# SPDX-License-Identifier: Apache-2.0

import logging
from dataclasses import asdict
from typing import Any, Dict, Union

import torch
from tensorrt_llm.inputs import default_multimodal_input_loader

import dynamo.nixl_connect as nixl_connect
from dynamo.trtllm.utils.disagg_utils import DisaggregatedParamsCodec


class EncodeHelper:
@@ -185,10 +188,14 @@ async def read_embeddings_from_encode_response(
return encodings_tensor

@staticmethod
async def process_embedding_request(
async def process_encode_request(
request: Dict[str, Any],
multimodal_processor,
connector: nixl_connect.Connector,
tokenizer = None,
model_dir = None,
model_type = None,
engine = None,
):
"""
Process embedding request by loading embeddings and creating NIXL readable operation.
@@ -203,57 +210,116 @@ async def process_embedding_request(
"""
# Load embeddings first to get the actual shape
messages = request.get("messages", [])
_, _, embedding_paths = multimodal_processor.extract_prompt_and_media(messages)

if not embedding_paths:
# Placeholder for TRTLLM Encoder to be called
# TRTLLM Encoder will return a memory handler on the encoder GPU with the encodings
logging.warning(
"No embedding paths found, NIXL transfer for image urls not supported by TRTLLM Encoder yet"
(
text_prompt,
image_urls,
embedding_paths,
) = multimodal_processor.extract_prompt_and_media(messages)
if embedding_paths:
# Load the embeddings data
loaded_data = multimodal_processor.load_tensor_from_path_or_url(
embedding_paths[0]
)
yield {"error": "No embedding paths found"}
return

# Load the embeddings data
loaded_data = multimodal_processor.load_tensor_from_path_or_url(
embedding_paths[0]
)

# Handle both tensor and dictionary formats
if isinstance(loaded_data, dict):
# Dictionary format (e.g., maverick_mm_embed_seashore_v3.pt)
encodings = loaded_data.get("mm_embeddings")
if encodings is None:
yield {"error": "Dictionary embeddings missing 'mm_embeddings' key"}
return

# Store auxiliary data for later transmission
auxiliary_data = {
k: v for k, v in loaded_data.items() if k != "mm_embeddings"
}
# Handle both tensor and dictionary formats
if isinstance(loaded_data, dict):
# Dictionary format (e.g., maverick_mm_embed_seashore_v3.pt)
encodings = loaded_data.get("mm_embeddings")
if encodings is None:
yield {"error": "Dictionary embeddings missing 'mm_embeddings' key"}
return

# Store auxiliary data for later transmission
auxiliary_data = {
k: v for k, v in loaded_data.items() if k != "mm_embeddings"
}
else:
# Tensor format (e.g., llava_next_mm_embed_seashore.pt)
encodings = loaded_data
auxiliary_data = {}

# Create readable operation with main embeddings tensor (works for both formats)
descriptor = nixl_connect.Descriptor(encodings)
with connector.create_readable(descriptor) as readable_op:
# Get the metadata for the readable operation
op_metadata = readable_op.metadata()

# Send back shape info, readable metadata, and serialized auxiliary data
response = {
"nixl_readable_metadata": op_metadata.model_dump(),
"embeddings_shape": list(encodings.shape),
"embeddings_dtype": str(encodings.dtype),
"auxiliary_data": EncodeHelper.serialize_tensor_dict(auxiliary_data),
}
yield response

# Wait for the prefill worker to complete the read operation
logging.debug(
"EncodeHelper waiting for PrefillHandler to read embeddings..."
)
await readable_op.wait_for_completion()
logging.debug("EncodeHelper completed readable operation.")
else:
# Tensor format (e.g., llava_next_mm_embed_seashore.pt)
encodings = loaded_data
auxiliary_data = {}

# Create readable operation with main embeddings tensor (works for both formats)
descriptor = nixl_connect.Descriptor(encodings)
with connector.create_readable(descriptor) as readable_op:
# Get the metadata for the readable operation
op_metadata = readable_op.metadata()

# Send back shape info, readable metadata, and serialized auxiliary data
response = {
"nixl_readable_metadata": op_metadata.model_dump(),
"embeddings_shape": list(encodings.shape),
"embeddings_dtype": str(encodings.dtype),
"auxiliary_data": EncodeHelper.serialize_tensor_dict(auxiliary_data),
}
yield response

# Wait for the prefill worker to complete the read operation
logging.debug(
"EncodeHelper waiting for PrefillHandler to read embeddings..."
logging.info("========== ENCODE WORKER: Full EPD - Using MultimodalEncoder ==========")
inputs = default_multimodal_input_loader(
tokenizer=tokenizer,
model_dir=model_dir,
model_type=model_type,
modality="image",
prompts=[text_prompt],
media=image_urls[0],
)
await readable_op.wait_for_completion()
logging.debug("EncodeHelper completed readable operation.")
# engine.llm is the MultimodalEncoder instance
# MultimodalEncoder.generate() returns a list of GenerationResult objects
encoder_outputs = list(engine.llm.generate(inputs))

if not encoder_outputs:
logging.error("ENCODE WORKER: encoder_outputs is empty")
yield {"ep_disaggregated_params": None}
return

logging.info(f"ENCODE WORKER: Received {len(encoder_outputs)} encoder output(s)")
ep_disaggregated_params = encoder_outputs[0].disaggregated_params

if ep_disaggregated_params is None:
logging.error("ENCODE WORKER: encoder_outputs[0].disaggregated_params is None")
yield {"ep_disaggregated_params": None}
return

if hasattr(ep_disaggregated_params, 'multimodal_embedding_handles') and ep_disaggregated_params.multimodal_embedding_handles:
logging.info(f"ENCODE WORKER: Generated {len(ep_disaggregated_params.multimodal_embedding_handles)} embedding handle(s)")
else:
logging.warning("ENCODE WORKER: ep_disaggregated_params has no multimodal_embedding_handles")

# Convert DisaggregatedParams to dict for JSON serialization over the network
# Use the same pattern as handler_base.py: asdict(DisaggregatedParamsCodec.encode(...))
logging.info(f"ENCODE WORKER: Before codec.encode - multimodal_embedding_handles: {getattr(ep_disaggregated_params, 'multimodal_embedding_handles', 'NOT FOUND')}")
encoded_params = DisaggregatedParamsCodec.encode(ep_disaggregated_params)
logging.info(f"ENCODE WORKER: After codec.encode - multimodal_embedding_handles: {getattr(encoded_params, 'multimodal_embedding_handles', 'NOT FOUND')}")
params_dict = asdict(encoded_params)
logging.info(f"ENCODE WORKER: After asdict - multimodal_embedding_handles in dict: {params_dict.get('multimodal_embedding_handles', 'NOT FOUND')}")

# Also send the processed prompt (which includes <image> tokens)
# default_multimodal_input_loader returns a list, get the first element
processed_prompt = None
prompt_token_ids = None

if isinstance(inputs, list) and len(inputs) > 0:
first_input = inputs[0]
if isinstance(first_input, dict):
processed_prompt = first_input.get("prompt")
else:
processed_prompt = getattr(first_input, 'prompt', None)

# Tokenize the processed prompt for prefill worker
if processed_prompt and tokenizer is not None:
prompt_token_ids = tokenizer.encode(processed_prompt, add_special_tokens=False)
logging.info(f"ENCODE WORKER: Tokenized processed_prompt (length={len(prompt_token_ids)})")

logging.info(f"ENCODE WORKER: Extracted processed_prompt: {processed_prompt}")

yield {
"ep_disaggregated_params": params_dict,
"processed_prompt": processed_prompt, # Prompt with <image> tokens
"prompt_token_ids": prompt_token_ids # Token IDs for consistency
}
return
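A rough sketch of how a prefill-side caller might consume what `process_encode_request` yields in the full-EPD branch. The helper name and return shape here are assumptions; the real prefill handler presumably also decodes `ep_disaggregated_params` back through `DisaggregatedParamsCodec` before handing it to TRT-LLM, a step elided below.

```python
# Hypothetical consumer of the encode worker's response stream.
from typing import Any, AsyncIterator, Dict, Optional, Tuple

async def consume_encode_response(
    encode_stream: AsyncIterator[Dict[str, Any]],
) -> Optional[Tuple[Dict[str, Any], Optional[str], Optional[list]]]:
    # process_encode_request is an async generator that yields a single response dict.
    async for response in encode_stream:
        params_dict = response.get("ep_disaggregated_params")
        if params_dict is None:
            # Encoder failed or produced no disaggregated params.
            return None
        return (
            params_dict,                       # serialized DisaggregatedParams fields
            response.get("processed_prompt"),  # prompt with <image> tokens inserted
            response.get("prompt_token_ids"),  # token ids matching that prompt
        )
    return None
```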
48 changes: 37 additions & 11 deletions components/src/dynamo/trtllm/engine.py
@@ -3,25 +3,39 @@

import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional
from typing import AsyncGenerator, Optional, Union

from tensorrt_llm import LLM
from tensorrt_llm import LLM, MultimodalEncoder

from dynamo.trtllm.constants import DisaggregationMode

logging.basicConfig(level=logging.DEBUG)


class TensorRTLLMEngine:
def __init__(self, engine_args):
def __init__(self, engine_args, disaggregation_mode=None):
self.engine_args = engine_args
self._llm: Optional[LLM] = None
self._llm: Optional[Union[LLM, MultimodalEncoder]] = None
self.disaggregation_mode = disaggregation_mode

async def initialize(self):
if not self._llm:
model = self.engine_args.pop("model")
self._llm = LLM(
model=model,
**self.engine_args,
)
if self.disaggregation_mode == DisaggregationMode.ENCODE:
# Initialize MultimodalEncoder for EPD flow
max_batch_size = self.engine_args.pop("max_batch_size", 1)
logging.info(
f"Initializing MultimodalEncoder with max_batch_size={max_batch_size}"
)
self._llm = MultimodalEncoder(
model=model,
max_batch_size=max_batch_size,
)
else:
self._llm = LLM(
model=model,
**self.engine_args,
)

async def cleanup(self):
if self._llm:
@@ -33,15 +47,27 @@ async def cleanup(self):
self._llm = None

@property
def llm(self):
def llm(self) -> Union[LLM, MultimodalEncoder]:
if not self._llm:
raise RuntimeError("Engine not initialized")
return self._llm


@asynccontextmanager
async def get_llm_engine(engine_args) -> AsyncGenerator[TensorRTLLMEngine, None]:
engine = TensorRTLLMEngine(engine_args)
async def get_llm_engine(
engine_args, disaggregation_mode=None
) -> AsyncGenerator[TensorRTLLMEngine, None]:
if disaggregation_mode == DisaggregationMode.DECODE:
logging.info(
f"Disaggregation mode: {disaggregation_mode}, setting disable_overlap_scheduler to False"
)
engine_args["disable_overlap_scheduler"] = False
elif disaggregation_mode == DisaggregationMode.PREFILL:
logging.info(
f"Disaggregation mode: {disaggregation_mode}, setting disable_overlap_scheduler to True"
)
engine_args["disable_overlap_scheduler"] = True
engine = TensorRTLLMEngine(engine_args, disaggregation_mode)
try:
await engine.initialize()
yield engine
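A minimal usage sketch of the updated factory, using only names introduced in this diff; the model path and batch size are example values, and actually constructing the encoder would load the model weights:

```python
# Sketch only: build an encode-mode engine via the updated context manager.
import asyncio

from dynamo.trtllm.constants import DisaggregationMode
from dynamo.trtllm.engine import get_llm_engine

async def main():
    engine_args = {
        "model": "llava-hf/llava-v1.6-mistral-7b-hf",
        "max_batch_size": 16,  # popped and forwarded to MultimodalEncoder in ENCODE mode
    }
    # In ENCODE mode the context manager builds a MultimodalEncoder instead of an LLM;
    # in PREFILL/DECODE modes it also overrides disable_overlap_scheduler as shown above.
    async with get_llm_engine(engine_args, DisaggregationMode.ENCODE) as engine:
        print(type(engine.llm).__name__)  # -> "MultimodalEncoder"

asyncio.run(main())
```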