components/src/dynamo/vllm/handlers.py: 125 changes (119 additions, 6 deletions)
@@ -5,9 +5,10 @@
import logging
import os
import tempfile
import time
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Final
from typing import Any, AsyncGenerator, Dict, Final, List, Optional

from vllm.inputs import TokensPrompt
from vllm.lora.request import LoRARequest
@@ -128,6 +129,14 @@ def build_sampling_params(
return sampling_params


def _request_contains_timing_metrics(request: Dict[str, Any]) -> bool:
"""Check if timing_metrics is requested in observability_fields."""
observability_fields: Optional[List[str]] = request.get("observability_fields")
if observability_fields is None:
return False
return "timing_metrics" in observability_fields


class BaseWorkerHandler(ABC):
"""
Request handler for the generate and clear_kv_blocks endpoints.
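The `_request_contains_timing_metrics` helper above is self-contained, so its behavior can be shown with a small standalone sketch. The request dicts below are fabricated examples; only the `observability_fields` key and the `"timing_metrics"` value come from the diff.

```python
from typing import Any, Dict, List, Optional


def _request_contains_timing_metrics(request: Dict[str, Any]) -> bool:
    """Check if timing_metrics is requested in observability_fields."""
    observability_fields: Optional[List[str]] = request.get("observability_fields")
    if observability_fields is None:
        return False
    return "timing_metrics" in observability_fields


# Fabricated request payloads, just to exercise the helper.
assert _request_contains_timing_metrics({"observability_fields": ["timing_metrics"]})
assert not _request_contains_timing_metrics({"observability_fields": ["other_field"]})
assert not _request_contains_timing_metrics({})  # key absent -> timing disabled
```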
@@ -627,7 +636,7 @@ async def generate_tokens(
out[
"completion_usage"
] = BaseWorkerHandler._build_completion_usage(
- request_output=res
+ request_output=res,
)
# Log completion with LoRA info (debug level to avoid log spam)
if lora_request:
@@ -686,6 +695,29 @@ async def generate(self, request, context):
request_id = context.id()
logger.debug(f"Decode Request ID: {request_id}")

# Check if timing metrics are requested
include_timing = _request_contains_timing_metrics(request)

# Initialize timing metrics using request_received_seconds from frontend (passed via PreprocessedRequest)
#
# TIMING METRICS:
# - Reliable durations: Use same-machine timestamps (e.g., decode_end - decode_start).
# We use time.perf_counter() for intra-worker duration calculations to ensure monotonic,
# high-resolution timing that's immune to system clock adjustments.
# - Cross-machine calculations (e.g., prefill_start - request_received) assume perfect NTP
# synchronization and should be used with UTMOST CAUTION due to clock drift. Even with NTP,
# clocks can drift by milliseconds each day, leading to negative durations or misleading latency values.
# These cross-machine metrics are useful for rough end-to-end analysis but should not be
# relied upon for precise performance measurements.
# - TODO: Measure actual overhead (network, queueing, etc.) - expected to be low but needs
# benchmarking
timing_metrics: Dict[str, float] = {}
if include_timing:
# Use request_received_seconds from the request (set by frontend) if available
frontend_received = request.get("request_received_seconds")
if frontend_received is not None:
timing_metrics["request_received_seconds"] = frontend_received

# Extract and decode multimodal data if present
multi_modal_data = await self._extract_multimodal_data(request)

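The timing comment above describes a pattern that can be illustrated in isolation: record one wall-clock anchor with `time.time()`, then derive later timestamps from `time.perf_counter()` deltas so same-machine durations stay monotonic even if the system clock is adjusted mid-request. A minimal sketch, with `time.sleep` standing in for real prefill/decode work:

```python
import time
from typing import Dict

timing_metrics: Dict[str, float] = {}

anchor_seconds = time.time()       # wall-clock anchor, comparable to other *_seconds fields
anchor_perf = time.perf_counter()  # monotonic reference for all intra-worker deltas

timing_metrics["decode_start_seconds"] = anchor_seconds

time.sleep(0.01)                   # stand-in for generating the first token
timing_metrics["decode_first_token_seconds"] = anchor_seconds + (
    time.perf_counter() - anchor_perf
)

time.sleep(0.01)                   # stand-in for generating the remaining tokens
timing_metrics["decode_end_seconds"] = anchor_seconds + (
    time.perf_counter() - anchor_perf
)

# Durations computed from the same perf_counter anchor are reliable; comparing
# them against a timestamp taken on another machine (e.g. the frontend's
# request_received_seconds) is subject to clock drift, as the comment warns.
ttft_seconds = (
    timing_metrics["decode_first_token_seconds"]
    - timing_metrics["decode_start_seconds"]
)
print(f"time to first token: {ttft_seconds:.3f}s")
```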
@@ -703,6 +735,17 @@ async def generate(self, request, context):
kv_params = prefill_result.get("disaggregated_params", {}).get(
"kv_transfer_params"
)
# Extract prefill timing from prefill_result if available
if include_timing:
prefill_timing = prefill_result.get("disaggregated_params", {}).get(
"timing_metrics"
)
if prefill_timing:
# Merge prefill timing but keep the frontend's request_received_seconds
received = timing_metrics.get("request_received_seconds")
timing_metrics.update(prefill_timing)
if received is not None:
timing_metrics["request_received_seconds"] = received
else:
kv_params = None

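In the disaggregated path above, the prefill worker's timing is merged into the decode worker's dict while the frontend's `request_received_seconds` is preserved. A tiny standalone illustration of that merge, using made-up timestamps (only the key names come from the diff):

```python
timing_metrics = {"request_received_seconds": 100.00}
prefill_timing = {
    "request_received_seconds": 100.20,  # prefill worker's own (possibly drifted) copy
    "prefill_start_seconds": 100.25,
    "prefill_end_seconds": 100.90,
}

# Same merge logic as the handler: update, then restore the frontend value.
received = timing_metrics.get("request_received_seconds")
timing_metrics.update(prefill_timing)
if received is not None:
    timing_metrics["request_received_seconds"] = received

assert timing_metrics["request_received_seconds"] == 100.00
assert timing_metrics["prefill_end_seconds"] == 100.90
```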
@@ -739,19 +782,58 @@ async def generate(self, request, context):

dp_rank = request.get("dp_rank", None)

# Track decode timing
first_token_sent = False

async with self._abort_monitor(context, request_id):
try:
# Record decode start time
if include_timing:
decode_start_seconds = time.time()
decode_start_perf_counter = time.perf_counter()
# If this is aggregated mode (no prefill_result), prefill_start == decode_start
if prefill_result is None:
timing_metrics["prefill_start_seconds"] = decode_start_seconds
timing_metrics["decode_start_seconds"] = decode_start_seconds

async for tok in self.generate_tokens(
prompt,
sampling_params,
request_id,
data_parallel_rank=dp_rank,
lora_request=lora_request,
):
# Capture first token timing
if include_timing and not first_token_sent:
first_token_time = decode_start_seconds + (
time.perf_counter() - decode_start_perf_counter
)
timing_metrics["decode_first_token_seconds"] = first_token_time
# In aggregated mode, prefill finishes when first token is generated
if prefill_result is None:
timing_metrics["prefill_end_seconds"] = first_token_time
first_token_sent = True

if prefill_result is not None and "completion_usage" in tok:
tok["completion_usage"][
"prompt_tokens_details"
] = prefill_prompt_tokens_details

# On finish, record decode_end_seconds and inject timing_metrics
# Note: request_finish_seconds is set in the Rust HTTP layer when the response actually leaves the server
if tok.get("finish_reason") is not None and include_timing:
timing_metrics["decode_end_seconds"] = decode_start_seconds + (
time.perf_counter() - decode_start_perf_counter
)

# Inject timing_metrics into disaggregated_params
if (
"disaggregated_params" not in tok
or tok["disaggregated_params"] is None
):
tok["disaggregated_params"] = {}
tok["disaggregated_params"]["timing_metrics"] = timing_metrics

yield tok
except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
@@ -788,6 +870,23 @@ async def generate(self, request, context):
request_id = context.id()
logger.debug(f"Prefill Request ID: {request_id}")

# Check if timing metrics are requested
include_timing = _request_contains_timing_metrics(request)

# Initialize timing metrics using request_received_seconds from frontend (passed via PreprocessedRequest)
# See DecodeWorkerHandler.generate() for timing metrics documentation
timing_metrics: Dict[str, float] = {}
if include_timing:
# Use request_received_seconds from the request (set by frontend) if available
frontend_received = request.get("request_received_seconds")
if frontend_received is not None:
timing_metrics["request_received_seconds"] = frontend_received

# Record prefill_start as when we start processing in the prefill worker
prefill_start_seconds = time.time()
prefill_start_perf_counter = time.perf_counter()
timing_metrics["prefill_start_seconds"] = prefill_start_seconds

# Extract and decode multimodal data if present
multi_modal_data = await self._extract_multimodal_data(request)

@@ -865,15 +964,29 @@ async def generate(self, request, context):

token_ids = res.outputs[0].token_ids if res.outputs else []

# Build disaggregated_params with kv_transfer_params and timing_metrics
disaggregated_params: Optional[Dict[str, Any]] = {}

if res.kv_transfer_params:
disaggregated_params[
"kv_transfer_params"
] = res.kv_transfer_params

if include_timing and timing_metrics:
timing_metrics[
"prefill_end_seconds"
] = prefill_start_seconds + (
time.perf_counter() - prefill_start_perf_counter
)
disaggregated_params["timing_metrics"] = timing_metrics

output: Dict[str, Any] = {
"token_ids": list(token_ids),
"disaggregated_params": (
{"kv_transfer_params": res.kv_transfer_params}
if res.kv_transfer_params
else None
disaggregated_params if disaggregated_params else None
),
"completion_usage": BaseWorkerHandler._build_completion_usage(
- request_output=res
+ request_output=res,
),
}

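The prefill response above collapses an empty `disaggregated_params` dict to `None`. A hypothetical helper that mirrors that inline logic, just to make the collapse explicit (the helper name is invented for this sketch):

```python
from typing import Any, Dict, Optional


def build_disaggregated_params(
    kv_transfer_params: Optional[Dict[str, Any]] = None,
    timing_metrics: Optional[Dict[str, float]] = None,
) -> Optional[Dict[str, Any]]:
    """Mirror of the inline logic: include only populated fields, else None."""
    params: Dict[str, Any] = {}
    if kv_transfer_params:
        params["kv_transfer_params"] = kv_transfer_params
    if timing_metrics:
        params["timing_metrics"] = timing_metrics
    return params or None


assert build_disaggregated_params() is None
assert build_disaggregated_params(timing_metrics={"prefill_start_seconds": 1.0}) == {
    "timing_metrics": {"prefill_start_seconds": 1.0}
}
```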
Second file in this diff:
@@ -178,7 +178,9 @@ async def stream_response(
if request.stream:
# Handle streaming response
num_output_text_so_far = 0
- async for raw_response in self.openai_serving.chat_completion_stream_generator(
+ async for (
+     raw_response
+ ) in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,
@@ -212,7 +214,9 @@
# Collect all chunks into a single response
full_response = None
num_output_text_so_far = 0
- async for raw_response in self.openai_serving.chat_completion_stream_generator(
+ async for (
+     raw_response
+ ) in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,