Working EPD
Signed-off-by: Indrajit Bhosale <[email protected]>
indrajit96 committed Nov 5, 2025
commit 62cd533169d0d2890bd04e7ee71f817d8e576653
6 changes: 6 additions & 0 deletions components/src/dynamo/trtllm/engine.py
@@ -51,6 +51,12 @@ def llm(self) -> Union[LLM, MultimodalEncoder]:
 
 @asynccontextmanager
 async def get_llm_engine(engine_args, disaggregation_mode=None) -> AsyncGenerator[TensorRTLLMEngine, None]:
+    if disaggregation_mode == DisaggregationMode.DECODE:
+        logging.info(f"Disaggregation mode: {disaggregation_mode}, setting disable_overlap_scheduler to False")
+        engine_args["disable_overlap_scheduler"] = False
+    elif disaggregation_mode == DisaggregationMode.PREFILL:
+        logging.info(f"Disaggregation mode: {disaggregation_mode}, setting disable_overlap_scheduler to True")
+        engine_args["disable_overlap_scheduler"] = True
     engine = TensorRTLLMEngine(engine_args, disaggregation_mode)
     try:
         await engine.initialize()
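Note: TRT-LLM's overlap scheduler prepares the next step while the current one executes; a prefill-only worker returns after a single step, so the flag is forced to disabled-scheduler there, while decode keeps it enabled. A minimal sketch of consuming this context manager, with string modes standing in for the DisaggregationMode enum and a dict standing in for the initialized TensorRTLLMEngine:

    import asyncio
    from contextlib import asynccontextmanager

    # Simplified stand-in for dynamo.trtllm.engine.get_llm_engine,
    # for illustration only.
    @asynccontextmanager
    async def get_llm_engine(engine_args, disaggregation_mode=None):
        if disaggregation_mode == "decode":
            engine_args["disable_overlap_scheduler"] = False
        elif disaggregation_mode == "prefill":
            engine_args["disable_overlap_scheduler"] = True
        yield engine_args  # the real code yields an initialized engine

    async def main():
        async with get_llm_engine({"model": "m"}, "prefill") as engine:
            print(engine)  # {'model': 'm', 'disable_overlap_scheduler': True}

    asyncio.run(main())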
5 changes: 3 additions & 2 deletions components/src/dynamo/trtllm/main.py
@@ -205,10 +205,11 @@ async def init(runtime: DistributedRuntime, config: Config):
     )
     modality = getattr(config, "modality", None) or "text"
     # Prefill worker needs overlap scheduler disabled for disaggregation
-    disable_overlap_scheduler = False
+    disable_overlap_scheduler = True
     if config.disaggregation_mode == DisaggregationMode.PREFILL:
-        disable_overlap_scheduler = True
+        disable_overlap_scheduler = False
 
+    logging.info(f"disable_overlap_scheduler: {disable_overlap_scheduler} for config.disaggregation_mode: {config.disaggregation_mode}")
     arg_map = {
         "model": model_path,
         "scheduler_config": scheduler_config,
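Note: this default interacts with the engine.py change above: init() seeds the value in arg_map, but get_llm_engine overwrites it once a disaggregation mode is known, so the mode-specific branch wins for prefill and decode workers. A toy sketch of that precedence (hypothetical helper names, string modes):

    def init_default(mode):
        # main.py after this commit: scheduler disabled by default,
        # with prefill getting False as the initial arg_map value.
        return False if mode == "prefill" else True

    def apply_engine_override(arg_map, mode):
        # engine.py: the mode-specific setting replaces the init() default.
        if mode == "decode":
            arg_map["disable_overlap_scheduler"] = False
        elif mode == "prefill":
            arg_map["disable_overlap_scheduler"] = True
        return arg_map

    args = {"disable_overlap_scheduler": init_default("prefill")}
    print(apply_engine_override(args, "prefill"))  # {'disable_overlap_scheduler': True}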
57 changes: 50 additions & 7 deletions components/src/dynamo/trtllm/request_handlers/handler_base.py
@@ -159,9 +159,27 @@ async def generate_locally(
         # is available for full EPD flow
         disaggregated_params = None
 
+        # Normalize OpenAI request format BEFORE processing
+        # This ensures max_tokens is in stop_conditions when we need to save it
+        if "stop_conditions" not in request:
+            request["stop_conditions"] = {}
+        if "max_tokens" in request and "max_tokens" not in request["stop_conditions"]:
+            request["stop_conditions"]["max_tokens"] = request.pop("max_tokens")
+            logging.info(f"Normalized OpenAI max_tokens to stop_conditions: {request['stop_conditions']['max_tokens']}")
+
+        if "sampling_options" not in request:
+            request["sampling_options"] = {}
+        if "temperature" in request and "temperature" not in request["sampling_options"]:
+            request["sampling_options"]["temperature"] = request.pop("temperature")
+
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            if "stop_conditions" not in request:
-                request["stop_conditions"] = {}
+            # Save the original max_tokens so the decode worker can restore it
+            if "max_tokens" in request["stop_conditions"]:
+                request["_original_max_tokens"] = request["stop_conditions"]["max_tokens"]
+                logging.info(f"PREFILL: Saved original max_tokens: {request['_original_max_tokens']}")
+            else:
+                logging.info("PREFILL: No max_tokens in request stop_conditions")
             request["stop_conditions"]["max_tokens"] = 1
             if ep_disaggregated_params:
                 ep_disaggregated_params.request_type = "context_only"
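The normalization step is a small, self-contained transform: flat OpenAI-style fields are moved into the nested stop_conditions / sampling_options dicts the handlers expect. A minimal sketch, assuming plain-dict requests as in this handler (normalize_request is a hypothetical name):

    def normalize_request(request: dict) -> dict:
        # Move flat OpenAI-style fields into the nested dicts the handlers use.
        request.setdefault("stop_conditions", {})
        if "max_tokens" in request and "max_tokens" not in request["stop_conditions"]:
            request["stop_conditions"]["max_tokens"] = request.pop("max_tokens")
        request.setdefault("sampling_options", {})
        if "temperature" in request and "temperature" not in request["sampling_options"]:
            request["sampling_options"]["temperature"] = request.pop("temperature")
        return request

    print(normalize_request({"max_tokens": 128, "temperature": 0.7}))
    # {'stop_conditions': {'max_tokens': 128}, 'sampling_options': {'temperature': 0.7}}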
@@ -191,6 +209,15 @@ async def generate_locally(
         # Check for multimodal request and process it
         # Now ep_disaggregated_params is properly set for both prefill and decode modes
         if self.disaggregation_mode == DisaggregationMode.DECODE:
+            # Restore original max_tokens for decode phase
+            if "_original_max_tokens" in request:
+                if "stop_conditions" not in request:
+                    request["stop_conditions"] = {}
+                request["stop_conditions"]["max_tokens"] = request["_original_max_tokens"]
+                logging.info(f"DECODE: Restored original max_tokens: {request['_original_max_tokens']}")
+            else:
+                logging.info(f"DECODE: No _original_max_tokens in request. Current max_tokens: {request.get('stop_conditions', {}).get('max_tokens', 'NOT SET')}")
+
             # Decode worker with generation_only mode
             # Pass the same inputs format as prefill
             if "_epd_processed_prompt" in request:
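Taken together, the prefill branch above and this decode branch implement a save/restore round trip for the caller's token budget: prefill caps max_tokens at 1 (it only produces the first token plus KV cache), stashes the original under _original_max_tokens, and decode puts it back before generating. A condensed sketch of both sides (hypothetical helper names):

    def prepare_prefill(request: dict) -> dict:
        stop = request.setdefault("stop_conditions", {})
        if "max_tokens" in stop:
            # Stash the caller's budget; prefill itself emits a single token.
            request["_original_max_tokens"] = stop["max_tokens"]
        stop["max_tokens"] = 1
        return request

    def prepare_decode(request: dict) -> dict:
        if "_original_max_tokens" in request:
            # Restore the caller's budget before generation continues.
            request.setdefault("stop_conditions", {})["max_tokens"] = request["_original_max_tokens"]
        return request

    req = prepare_prefill({"stop_conditions": {"max_tokens": 256}})
    assert req["stop_conditions"]["max_tokens"] == 1
    assert prepare_decode(req)["stop_conditions"]["max_tokens"] == 256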
@@ -271,13 +298,23 @@ async def generate_locally(
         processors = [HelloWorldLogitsProcessor(self.engine.llm.tokenizer)]
         adapters = create_trtllm_adapters(processors)
         sampling_params.logits_processor = adapters
 
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            logging.info("Generate called for DECODE mode")
+        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+            logging.info("Generate called for PREFILL mode")
+        else:
+            logging.info("Generate called for ENCODE mode")
+        logging.info(f"Generate locally: processed_input: {processed_input}")
+        logging.info(f"Generate locally: sampling_params: {sampling_params}")
+        logging.info(f"Generate locally: disaggregated_params: {disaggregated_params}")
+        logging.info(f"Generate locally: streaming: {streaming}")
        generation_result = self.engine.llm.generate_async(
             inputs=processed_input,
             sampling_params=sampling_params,
             disaggregated_params=disaggregated_params,
             streaming=streaming,
         )
+        logging.info(f"Generate locally: generation_result: {generation_result}")
 
         # Use the context manager to handle cancellation monitoring
         async with self._cancellation_monitor(generation_result, context):
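_cancellation_monitor itself is not part of this diff; the sketch below only illustrates the usual shape of such a monitor (a background task that aborts the in-flight generation when the client goes away). Both context.is_stopped() and generation_result.abort() are assumptions here, not confirmed APIs:

    import asyncio
    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def cancellation_monitor(generation_result, context):
        # Poll for client cancellation in the background; abort if seen.
        async def watch():
            while not context.is_stopped():   # assumed API
                await asyncio.sleep(0.05)
            generation_result.abort()         # assumed API
        task = asyncio.create_task(watch())
        try:
            yield generation_result
        finally:
            task.cancel()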
@@ -325,10 +362,16 @@ async def generate_locally(
                 out["disaggregated_params"] = asdict(encoded_params)
 
             # Pass the processed prompt and token IDs for decode worker
-            if "_epd_processed_prompt" in request:
-                out["_epd_processed_prompt"] = request["_epd_processed_prompt"]
-            if "_epd_prompt_token_ids" in request and request["_epd_prompt_token_ids"]:
-                out["_epd_prompt_token_ids"] = request["_epd_prompt_token_ids"]
+            # Use the actual prompt and token IDs from the RequestOutput (res),
+            # which include the image placeholder tokens processed by TRTLLM
+            if "_epd_processed_prompt" in request and res.prompt:
+                out["_epd_processed_prompt"] = res.prompt
+            if "_epd_prompt_token_ids" in request and res.prompt_token_ids:
+                out["_epd_prompt_token_ids"] = res.prompt_token_ids
+
+            # Pass the original max_tokens to decode worker
+            if "_original_max_tokens" in request:
+                out["_original_max_tokens"] = request["_original_max_tokens"]
 
             if res.finished and not out.get("finish_reason"):
                 out["finish_reason"] = "unknown"
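Condensed, the prefill output now forwards three things to decode: the encoded disaggregated params, TRT-LLM's own view of the prompt and token IDs from the RequestOutput (which already contains the expanded multimodal placeholder tokens, unlike the raw request copy), and the stashed token budget. A schematic sketch (build_prefill_output is a hypothetical name; res mimics a RequestOutput):

    from types import SimpleNamespace

    def build_prefill_output(request: dict, res, encoded_params: dict) -> dict:
        out = {"disaggregated_params": encoded_params}
        # Prefer res.prompt / res.prompt_token_ids over the request copies:
        # TRT-LLM has already expanded image placeholder tokens there.
        if "_epd_processed_prompt" in request and res.prompt:
            out["_epd_processed_prompt"] = res.prompt
        if "_epd_prompt_token_ids" in request and res.prompt_token_ids:
            out["_epd_prompt_token_ids"] = res.prompt_token_ids
        if "_original_max_tokens" in request:
            out["_original_max_tokens"] = request["_original_max_tokens"]
        return out

    res = SimpleNamespace(prompt="<image> hi", prompt_token_ids=[1, 2, 3])
    req = {"_epd_processed_prompt": "hi", "_epd_prompt_token_ids": [1], "_original_max_tokens": 64}
    print(build_prefill_output(req, res, {"request_type": "context_only"}))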
7 changes: 6 additions & 1 deletion components/src/dynamo/trtllm/request_handlers/handlers.py
@@ -150,7 +150,7 @@ async def remote_decode(self, request: dict, context: Context):
 
     async def generate(self, request: dict, context: Context):
         logging.debug(f"New Request ID: {context.id()}")
-        logging.debug(f"PrefillHandler.generate received request: {request}")
+        logging.info(f"PrefillHandler.generate received request with stop_conditions: {request.get('stop_conditions', 'NOT SET')}")
         embeddings_tensor = None
         ep_disaggregated_params = None
 
@@ -190,11 +190,13 @@ async def generate(self, request: dict, context: Context):
         prefill_request = copy.deepcopy(request)
         prefill_response = None
         response_count = 0
+        logging.info(f"PrefillHandler: Local generate request: {prefill_request}")
         async for res in self.generate_locally(
             prefill_request, context, embeddings_tensor, ep_disaggregated_params
         ):
             prefill_response = res
             response_count += 1
+            logging.info(f"PrefillHandler: Local generate response: {res}")
             if response_count > 1:
                 raise ValueError("Prefill response should be generated only once.")
 
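The response_count guard encodes an invariant: with max_tokens pinned to 1, the prefill engine must yield exactly one response. A minimal reusable version of the same check (expect_single_response is a hypothetical helper):

    import asyncio

    async def expect_single_response(stream):
        response, count = None, 0
        async for res in stream:
            response, count = res, count + 1
            if count > 1:
                raise ValueError("Prefill response should be generated only once.")
        return response

    async def demo():
        async def one():
            yield {"finish_reason": "length"}
        print(await expect_single_response(one()))

    asyncio.run(demo())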
@@ -298,6 +300,9 @@ async def generate(self, request: dict, context: Context):
             # Extract pre-computed token IDs from encoder for consistency
             if "_epd_prompt_token_ids" in response_data and response_data["_epd_prompt_token_ids"]:
                 request["_epd_prompt_token_ids"] = response_data["_epd_prompt_token_ids"]
+            # Extract original max_tokens for decode phase
+            if "_original_max_tokens" in response_data:
+                request["_original_max_tokens"] = response_data["_original_max_tokens"]
 
             async for res in self.generate_locally(request, context):
                 yield res
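On the decode side this is the mirror of the prefill output block in handler_base.py: the fields the prefill worker attached to its response are copied back into the local request before generate_locally runs. A small hypothetical helper with a usage line:

    def merge_prefill_response(request: dict, response_data: dict) -> dict:
        # Copy forwarded EPD fields and the stashed token budget into the request.
        for key in ("_epd_processed_prompt", "_epd_prompt_token_ids", "_original_max_tokens"):
            if key in response_data:
                request[key] = response_data[key]
        return request

    req = merge_prefill_response({}, {"_original_max_tokens": 64})
    assert req["_original_max_tokens"] == 64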