diff --git a/components/backends/trtllm/src/dynamo/trtllm/health_check.py b/components/backends/trtllm/src/dynamo/trtllm/health_check.py
index c12beeae7fc..0178f6902ef 100644
--- a/components/backends/trtllm/src/dynamo/trtllm/health_check.py
+++ b/components/backends/trtllm/src/dynamo/trtllm/health_check.py
@@ -21,11 +21,27 @@ def __init__(self):
         """
         Initialize TRT-LLM health check payload with TRT-LLM-specific defaults.
         """
-        # Set TRT-LLM default payload - minimal request that completes quickly
+        # Set TensorRT-LLM default payload - minimal request that completes quickly
+        # The handler expects token_ids, stop_conditions, and sampling_options
         self.default_payload = {
-            "messages": [{"role": "user", "content": "1"}],
-            "max_tokens": 1,
-            "temperature": 0.0,
-            "stream": False,
+            "token_ids": [1],  # Single token for minimal processing
+            "stop_conditions": {
+                "max_tokens": 1,  # Generate only 1 token
+                "stop": None,
+                "stop_token_ids": None,
+                "include_stop_str_in_output": False,
+                "ignore_eos": False,
+                "min_tokens": 0,
+            },
+            "sampling_options": {
+                "temperature": 0.0,
+                "top_p": 1.0,
+                "top_k": 1,
+                "beam_width": 1,
+                "repetition_penalty": 1.0,
+                "presence_penalty": 0.0,
+                "frequency_penalty": 0.0,
+                "seed": None,
+            },
         }
         super().__init__()
diff --git a/components/backends/trtllm/src/dynamo/trtllm/main.py b/components/backends/trtllm/src/dynamo/trtllm/main.py
index 7c6b47958ac..2ecdad232c3 100644
--- a/components/backends/trtllm/src/dynamo/trtllm/main.py
+++ b/components/backends/trtllm/src/dynamo/trtllm/main.py
@@ -27,6 +27,7 @@
 from dynamo.runtime import DistributedRuntime, dynamo_worker
 from dynamo.runtime.logging import configure_dynamo_logging
 from dynamo.trtllm.engine import TensorRTLLMEngine, get_llm_engine
+from dynamo.trtllm.health_check import TrtllmHealthCheckPayload
 from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
 from dynamo.trtllm.publisher import get_publisher
 from dynamo.trtllm.request_handlers.handlers import (
@@ -316,6 +317,9 @@ async def init(runtime: DistributedRuntime, config: Config):
         runtime_config=runtime_config,
     )
 
+    # Get health check payload (checks env var and falls back to TensorRT-LLM default)
+    health_check_payload = TrtllmHealthCheckPayload().to_dict()
+
     if config.publish_events_and_metrics and is_first_worker(config):
         # Initialize and pass in the publisher to the request handler to
         # publish events and metrics.
@@ -334,11 +338,15 @@ async def init(runtime: DistributedRuntime, config: Config):
         handler_config.publisher = publisher
         handler = RequestHandlerFactory().get_request_handler(handler_config)
         await endpoint.serve_endpoint(
-            handler.generate, metrics_labels=metrics_labels
+            handler.generate,
+            metrics_labels=metrics_labels,
+            health_check_payload=health_check_payload,
         )
     else:
         handler = RequestHandlerFactory().get_request_handler(handler_config)
-        await endpoint.serve_endpoint(handler.generate)
+        await endpoint.serve_endpoint(
+            handler.generate, health_check_payload=health_check_payload
+        )
 
 
 def main():