Draft PR
Signed-off-by: Indrajit Bhosale <[email protected]>
indrajit96 committed Oct 22, 2025
commit b0d0dbe99b3e285403bd65e9774fdd510da3b042
4 changes: 2 additions & 2 deletions components/backends/trtllm/launch/epd_disagg.sh
@@ -3,8 +3,8 @@
# SPDX-License-Identifier: Apache-2.0

# Environment variables with defaults
export MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen2-VL-7B-Instruct"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"Qwen/Qwen2-VL-7B-Instruct"}
export MODEL_PATH=${MODEL_PATH:-"llava-hf/llava-v1.6-mistral-7b-hf"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"llava-v1.6-mistral-7b-hf"}
export DISAGGREGATION_STRATEGY=${DISAGGREGATION_STRATEGY:-"decode_first"}
export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"engine_configs/multimodal/prefill.yaml"}
export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"engine_configs/multimodal/decode.yaml"}
16 changes: 16 additions & 0 deletions components/src/dynamo/trtllm/constants.py
@@ -0,0 +1,16 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from enum import Enum


class DisaggregationMode(Enum):
AGGREGATED = "prefill_and_decode"
PREFILL = "prefill"
DECODE = "decode"
ENCODE = "encode"


class DisaggregationStrategy(Enum):
PREFILL_FIRST = "prefill_first"
DECODE_FIRST = "decode_first"
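These enum values double as the wire strings used elsewhere in this PR (the launch script's DISAGGREGATION_STRATEGY default, the mode handed to the engine), so they can be looked up by value. A minimal standalone sketch:

```python
from dynamo.trtllm.constants import DisaggregationMode, DisaggregationStrategy

# Enum lookup by value round-trips the serialized strings.
mode = DisaggregationMode("encode")
assert mode is DisaggregationMode.ENCODE

# Matches the launch script default DISAGGREGATION_STRATEGY="decode_first".
strategy = DisaggregationStrategy("decode_first")
print(mode.value, strategy.value)  # -> encode decode_first
```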
170 changes: 118 additions & 52 deletions components/src/dynamo/trtllm/encode_helper.py
@@ -2,11 +2,14 @@
# SPDX-License-Identifier: Apache-2.0

import logging
from dataclasses import asdict
from typing import Any, Dict, Union

import torch
from tensorrt_llm.inputs import default_multimodal_input_loader

import dynamo.nixl_connect as nixl_connect
from dynamo.trtllm.utils.disagg_utils import DisaggregatedParamsCodec


class EncodeHelper:
@@ -185,10 +188,14 @@ async def read_embeddings_from_encode_response(
return encodings_tensor

@staticmethod
async def process_embedding_request(
async def process_encode_request(
request: Dict[str, Any],
multimodal_processor,
connector: nixl_connect.Connector,
tokenizer=None,
model_dir=None,
model_type=None,
engine=None,
):
"""
Process an encode request: serve pre-computed embeddings over a NIXL readable operation when embedding paths are supplied, otherwise run the TRTLLM MultimodalEncoder for the full EPD flow.
@@ -203,57 +210,116 @@ async def process_embedding_request(
"""
# Load embeddings first to get the actual shape
messages = request.get("messages", [])
_, _, embedding_paths = multimodal_processor.extract_prompt_and_media(messages)

if not embedding_paths:
# Placeholder for TRTLLM Encoder to be called
# TRTLLM Encoder will return a memory handler on the encoder GPU with the encodings
logging.warning(
"No embedding paths found, NIXL transfer for image urls not supported by TRTLLM Encoder yet"
(
text_prompts,
image_urls,
embedding_paths,
) = multimodal_processor.extract_prompt_and_media(messages)
if embedding_paths:
# Load the embeddings data
loaded_data = multimodal_processor.load_tensor_from_path_or_url(
embedding_paths[0]
)
yield {"error": "No embedding paths found"}
return

# Load the embeddings data
loaded_data = multimodal_processor.load_tensor_from_path_or_url(
embedding_paths[0]
)

# Handle both tensor and dictionary formats
if isinstance(loaded_data, dict):
# Dictionary format (e.g., maverick_mm_embed_seashore_v3.pt)
encodings = loaded_data.get("mm_embeddings")
if encodings is None:
yield {"error": "Dictionary embeddings missing 'mm_embeddings' key"}
return

# Store auxiliary data for later transmission
auxiliary_data = {
k: v for k, v in loaded_data.items() if k != "mm_embeddings"
}
# Handle both tensor and dictionary formats
if isinstance(loaded_data, dict):
# Dictionary format (e.g., maverick_mm_embed_seashore_v3.pt)
encodings = loaded_data.get("mm_embeddings")
if encodings is None:
yield {"error": "Dictionary embeddings missing 'mm_embeddings' key"}
return

# Store auxiliary data for later transmission
auxiliary_data = {
k: v for k, v in loaded_data.items() if k != "mm_embeddings"
}
else:
# Tensor format (e.g., llava_next_mm_embed_seashore.pt)
encodings = loaded_data
auxiliary_data = {}

# Create readable operation with main embeddings tensor (works for both formats)
descriptor = nixl_connect.Descriptor(encodings)
with connector.create_readable(descriptor) as readable_op:
# Get the metadata for the readable operation
op_metadata = readable_op.metadata()

# Send back shape info, readable metadata, and serialized auxiliary data
response = {
"nixl_readable_metadata": op_metadata.model_dump(),
"embeddings_shape": list(encodings.shape),
"embeddings_dtype": str(encodings.dtype),
"auxiliary_data": EncodeHelper.serialize_tensor_dict(auxiliary_data),
}
yield response

# Wait for the prefill worker to complete the read operation
logging.debug(
"EncodeHelper waiting for PrefillHandler to read embeddings..."
)
await readable_op.wait_for_completion()
logging.debug("EncodeHelper completed readable operation.")
else:
# Tensor format (e.g., llava_next_mm_embed_seashore.pt)
encodings = loaded_data
auxiliary_data = {}

# Create readable operation with main embeddings tensor (works for both formats)
descriptor = nixl_connect.Descriptor(encodings)
with connector.create_readable(descriptor) as readable_op:
# Get the metadata for the readable operation
op_metadata = readable_op.metadata()

# Send back shape info, readable metadata, and serialized auxiliary data
response = {
"nixl_readable_metadata": op_metadata.model_dump(),
"embeddings_shape": list(encodings.shape),
"embeddings_dtype": str(encodings.dtype),
"auxiliary_data": EncodeHelper.serialize_tensor_dict(auxiliary_data),
}
yield response

# Wait for the prefill worker to complete the read operation
logging.debug(
"EncodeHelper waiting for PrefillHandler to read embeddings..."
logging.info("========== ENCODE WORKER: Full EPD - Using MultimodalEncoder ==========")
inputs = default_multimodal_input_loader(
tokenizer=tokenizer,
model_dir=model_dir,
model_type=model_type,
modality="image",
prompts=text_prompts[0],
media=image_urls[0],
)
await readable_op.wait_for_completion()
logging.debug("EncodeHelper completed readable operation.")
# engine.llm is the MultimodalEncoder instance
# MultimodalEncoder.generate() returns a list of GenerationResult objects
encoder_outputs = list(engine.llm.generate(inputs))

if not encoder_outputs:
logging.error("ENCODE WORKER: encoder_outputs is empty")
yield {"ep_disaggregated_params": None}
return

logging.info(f"ENCODE WORKER: Received {len(encoder_outputs)} encoder output(s)")
ep_disaggregated_params = encoder_outputs[0].disaggregated_params

if ep_disaggregated_params is None:
logging.error("ENCODE WORKER: encoder_outputs[0].disaggregated_params is None")
yield {"ep_disaggregated_params": None}
return

if hasattr(ep_disaggregated_params, 'multimodal_embedding_handles') and ep_disaggregated_params.multimodal_embedding_handles:
logging.info(f"ENCODE WORKER: Generated {len(ep_disaggregated_params.multimodal_embedding_handles)} embedding handle(s)")
else:
logging.warning("ENCODE WORKER: ep_disaggregated_params has no multimodal_embedding_handles")

# Convert DisaggregatedParams to dict for JSON serialization over the network
# Use the same pattern as handler_base.py: asdict(DisaggregatedParamsCodec.encode(...))
logging.info(f"ENCODE WORKER: Before codec.encode - multimodal_embedding_handles: {getattr(ep_disaggregated_params, 'multimodal_embedding_handles', 'NOT FOUND')}")
encoded_params = DisaggregatedParamsCodec.encode(ep_disaggregated_params)
logging.info(f"ENCODE WORKER: After codec.encode - multimodal_embedding_handles: {getattr(encoded_params, 'multimodal_embedding_handles', 'NOT FOUND')}")
params_dict = asdict(encoded_params)
logging.info(f"ENCODE WORKER: After asdict - multimodal_embedding_handles in dict: {params_dict.get('multimodal_embedding_handles', 'NOT FOUND')}")

# Also send the processed prompt (which includes <image> tokens)
# default_multimodal_input_loader returns a list, get the first element
processed_prompt = None
prompt_token_ids = None

if isinstance(inputs, list) and len(inputs) > 0:
first_input = inputs[0]
if isinstance(first_input, dict):
processed_prompt = first_input.get("prompt")
else:
processed_prompt = getattr(first_input, 'prompt', None)

# Tokenize the processed prompt for prefill worker
if processed_prompt and tokenizer is not None:
prompt_token_ids = tokenizer.encode(processed_prompt, add_special_tokens=False)
logging.info(f"ENCODE WORKER: Tokenized processed_prompt (length={len(prompt_token_ids)})")

logging.info(f"ENCODE WORKER: Extracted processed_prompt: {processed_prompt}")

yield {
"ep_disaggregated_params": params_dict,
"processed_prompt": processed_prompt, # Prompt with <image> tokens
"prompt_token_ids": prompt_token_ids # Token IDs for consistency
}
return
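For orientation, a hedged sketch of the prefill-side counterpart that consumes the `ep_disaggregated_params` payload yielded above. Assumptions not shown in this diff: `DisaggregatedParamsCodec` exposes a `decode()` mirroring the `encode()` used here, and the dataclass being round-tripped is tensorrt_llm's `DisaggregatedParams` (import path assumed).

```python
from dataclasses import fields

from tensorrt_llm.llmapi import DisaggregatedParams  # import path assumed

from dynamo.trtllm.utils.disagg_utils import DisaggregatedParamsCodec


def rebuild_ep_params(response: dict) -> DisaggregatedParams:
    """Invert the asdict(DisaggregatedParamsCodec.encode(...)) step above."""
    params_dict = response.get("ep_disaggregated_params")
    if params_dict is None:
        raise ValueError("encode worker returned no disaggregated params")
    # asdict() flattened the dataclass; rebuild it field by field so stray
    # keys in the network payload cannot reach the constructor.
    kwargs = {f.name: params_dict.get(f.name) for f in fields(DisaggregatedParams)}
    return DisaggregatedParamsCodec.decode(DisaggregatedParams(**kwargs))
```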
34 changes: 22 additions & 12 deletions components/src/dynamo/trtllm/engine.py
@@ -3,25 +3,35 @@

import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional

from tensorrt_llm import LLM
from typing import AsyncGenerator, Optional, Union

from tensorrt_llm import LLM, MultimodalEncoder
from dynamo.trtllm.constants import DisaggregationMode
logging.basicConfig(level=logging.DEBUG)


class TensorRTLLMEngine:
def __init__(self, engine_args):
def __init__(self, engine_args, disaggregation_mode=None):
self.engine_args = engine_args
self._llm: Optional[LLM] = None
self._llm: Optional[Union[LLM, MultimodalEncoder]] = None
self.disaggregation_mode = disaggregation_mode

async def initialize(self):
if not self._llm:
model = self.engine_args.pop("model")
self._llm = LLM(
model=model,
**self.engine_args,
)
if self.disaggregation_mode == DisaggregationMode.ENCODE:
# Initialize MultimodalEncoder for EPD flow
max_batch_size = self.engine_args.pop("max_batch_size", 1)
logging.info(f"Initializing MultimodalEncoder with max_batch_size={max_batch_size}")
self._llm = MultimodalEncoder(
model=model,
max_batch_size=max_batch_size,
)
else:
self._llm = LLM(
model=model,
**self.engine_args,
)

async def cleanup(self):
if self._llm:
@@ -33,15 +43,15 @@ async def cleanup(self):
self._llm = None

@property
def llm(self):
def llm(self) -> Union[LLM, MultimodalEncoder]:
if not self._llm:
raise RuntimeError("Engine not initialized")
return self._llm


@asynccontextmanager
async def get_llm_engine(engine_args) -> AsyncGenerator[TensorRTLLMEngine, None]:
engine = TensorRTLLMEngine(engine_args)
async def get_llm_engine(engine_args, disaggregation_mode=None) -> AsyncGenerator[TensorRTLLMEngine, None]:
engine = TensorRTLLMEngine(engine_args, disaggregation_mode)
try:
await engine.initialize()
yield engine
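Usage sketch for the updated context manager (model name and batch size illustrative): in ENCODE mode the yielded engine wraps a MultimodalEncoder rather than an LLM.

```python
import asyncio

from dynamo.trtllm.constants import DisaggregationMode
from dynamo.trtllm.engine import get_llm_engine


async def run_encode_worker() -> None:
    engine_args = {
        "model": "llava-hf/llava-v1.6-mistral-7b-hf",  # illustrative checkpoint
        "max_batch_size": 2,
    }
    async with get_llm_engine(engine_args, DisaggregationMode.ENCODE) as engine:
        # initialize() routed to MultimodalEncoder because of the mode.
        print(type(engine.llm).__name__)  # -> MultimodalEncoder


asyncio.run(run_encode_worker())
```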
21 changes: 14 additions & 7 deletions components/src/dynamo/trtllm/main.py
@@ -31,7 +31,7 @@
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from torch.cuda import device_count
from transformers import AutoConfig
from transformers import AutoConfig, GenerationConfig

import dynamo.nixl_connect as nixl_connect
from dynamo.common.config_dump import dump_config
@@ -262,8 +262,17 @@ async def init(runtime: DistributedRuntime, config: Config):

# Populate default sampling params from the model
tokenizer = tokenizer_factory(arg_map["model"])

# Load HF model config and generation config for SamplingParams._setup()
hf_model_config = AutoConfig.from_pretrained(
arg_map["model"], trust_remote_code=True
)
generation_config = GenerationConfig.from_pretrained(
arg_map["model"], trust_remote_code=True
)

default_sampling_params = SamplingParams()
default_sampling_params._setup(tokenizer)
default_sampling_params._setup(tokenizer, hf_model_config, generation_config)
default_sampling_params.stop = None
model_input = ModelInput.Tokens
model_type = ModelType.Chat | ModelType.Completions
@@ -278,11 +287,9 @@ async def init(runtime: DistributedRuntime, config: Config):
if modality == "multimodal":
engine_args["skip_tokenizer_init"] = False
model_input = ModelInput.Text
model_config = AutoConfig.from_pretrained(
config.model_path, trust_remote_code=True
)
# Reuse the hf_model_config loaded earlier
multimodal_processor = MultimodalRequestProcessor(
model_type=model_config.model_type,
model_type=hf_model_config.model_type,
model_dir=config.model_path,
max_file_size_mb=config.max_file_size_mb,
tokenizer=tokenizer,
@@ -302,7 +309,7 @@ async def init(runtime: DistributedRuntime, config: Config):
config.dump_config_to, {"engine_args": engine_args, "dynamo_args": config}
)

async with get_llm_engine(engine_args) as engine:
async with get_llm_engine(engine_args, config.disaggregation_mode) as engine:
endpoint = component.endpoint(config.endpoint)

# should ideally call get_engine_runtime_config
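One caveat on the new GenerationConfig.from_pretrained call: checkpoints that ship no generation_config.json make transformers raise OSError, which would now fail worker startup. A defensive variant (a suggestion, not something this PR does):

```python
from transformers import AutoConfig, GenerationConfig


def load_generation_config(model_id: str) -> GenerationConfig:
    """Fall back to config-derived defaults when generation_config.json is absent."""
    try:
        return GenerationConfig.from_pretrained(model_id, trust_remote_code=True)
    except OSError:
        hf_model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
        return GenerationConfig.from_model_config(hf_model_config)
```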