117 changes: 107 additions & 10 deletions components/src/dynamo/vllm/handlers.py
@@ -5,9 +5,10 @@
import logging
import os
import tempfile
import time
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Final
from typing import Any, AsyncGenerator, Dict, Final, List, Optional

from vllm.inputs import TokensPrompt
from vllm.outputs import RequestOutput
@@ -73,6 +74,19 @@ def build_sampling_params(
return sampling_params


def _should_include_timing_metrics(request: Dict[str, Any]) -> bool:
"""Check if timing_metrics is requested in extra_fields."""
extra_fields: Optional[List[str]] = request.get("extra_fields")
if extra_fields is None:
return False
return "timing_metrics" in extra_fields


def _get_current_time_seconds() -> float:
"""Get the current time in seconds since epoch as a float."""
return time.time()


class BaseWorkerHandler(ABC):
"""
Request handler for the generate and clear_kv_blocks endpoints.
@@ -250,10 +264,10 @@ async def generate_tokens(
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
-out[
-    "completion_usage"
-] = BaseWorkerHandler._build_completion_usage(
-    request_output=res
+out["completion_usage"] = (
+    BaseWorkerHandler._build_completion_usage(
+        request_output=res
+    )
)
if output.stop_reason:
out["stop_reason"] = output.stop_reason
@@ -296,6 +310,18 @@ async def generate(self, request, context):
request_id = context.id()
logger.debug(f"Decode Request ID: {request_id}")

# Check if timing metrics are requested
include_timing = _should_include_timing_metrics(request)

# Initialize timing metrics using request_received_seconds from frontend (passed via PreprocessedRequest)
timing_metrics: Optional[Dict[str, float]] = None
if include_timing:
timing_metrics = {}
# Use request_received_seconds from the request (set by frontend) if available
frontend_received = request.get("request_received_seconds")
if frontend_received is not None:
timing_metrics["request_received_seconds"] = frontend_received

# Extract and decode multimodal data if present
multi_modal_data = await self._extract_multimodal_data(request)

@@ -313,6 +339,17 @@
kv_params = prefill_result.get("disaggregated_params", {}).get(
"kv_transfer_params"
)
# Extract prefill timing from prefill_result if available
if include_timing:
prefill_timing = prefill_result.get("disaggregated_params", {}).get(
"timing_metrics"
)
if prefill_timing:
# Merge prefill timing but keep the frontend's request_received_seconds
received = timing_metrics.get("request_received_seconds")
timing_metrics.update(prefill_timing)
if received is not None:
timing_metrics["request_received_seconds"] = received
else:
kv_params = None

@@ -329,15 +366,52 @@

dp_rank = request.get("dp_rank", None)

# Track decode timing
first_token_sent = False
decode_start_seconds: Optional[float] = None

async with self._abort_monitor(context, request_id):
try:
# Record decode start time
if include_timing:
decode_start_seconds = _get_current_time_seconds()
# If this is aggregated mode (no prefill_result), prefill_start == decode_start
if prefill_result is None:
timing_metrics["prefill_start_seconds"] = decode_start_seconds
timing_metrics["decode_start_seconds"] = decode_start_seconds

async for tok in self.generate_tokens(
prompt, sampling_params, request_id, data_parallel_rank=dp_rank
):
# Capture first token timing
if include_timing and not first_token_sent:
first_token_time = _get_current_time_seconds()
timing_metrics["decode_first_token_seconds"] = first_token_time
# If aggregated mode, prefill finishes when first token is generated
if prefill_result is None:
timing_metrics["prefill_end_seconds"] = first_token_time
first_token_sent = True

if prefill_result is not None and "completion_usage" in tok:
tok["completion_usage"][
"prompt_tokens_details"
] = prefill_prompt_tokens_details

# On finish, record decode_end_seconds and inject timing_metrics
# Note: request_completed_seconds is set in the Rust HTTP layer when the response actually leaves the server
if tok.get("finish_reason") is not None and include_timing:
timing_metrics["decode_end_seconds"] = (
_get_current_time_seconds()
)

# Inject timing_metrics into disaggregated_params
if (
"disaggregated_params" not in tok
or tok["disaggregated_params"] is None
):
tok["disaggregated_params"] = {}
tok["disaggregated_params"]["timing_metrics"] = timing_metrics

yield tok
except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
@@ -370,6 +444,21 @@ async def generate(self, request, context):
request_id = context.id()
logger.debug(f"Prefill Request ID: {request_id}")

# Check if timing metrics are requested
include_timing = _should_include_timing_metrics(request)

# Initialize timing metrics using request_received_seconds from frontend (passed via PreprocessedRequest)
timing_metrics: Optional[Dict[str, float]] = None
if include_timing:
timing_metrics = {}
# Use request_received_seconds from the request (set by frontend) if available
frontend_received = request.get("request_received_seconds")
if frontend_received is not None:
timing_metrics["request_received_seconds"] = frontend_received

# Record prefill_start as when we start processing in the prefill worker
timing_metrics["prefill_start_seconds"] = _get_current_time_seconds()

# Extract and decode multimodal data if present
multi_modal_data = await self._extract_multimodal_data(request)

@@ -422,13 +511,21 @@

token_ids = res.outputs[0].token_ids if res.outputs else []

# Build disaggregated_params with kv_transfer_params and timing_metrics
disaggregated_params: Dict[str, Any] = {}

if res.kv_transfer_params:
disaggregated_params["kv_transfer_params"] = res.kv_transfer_params

if include_timing and timing_metrics:
timing_metrics["prefill_end_seconds"] = (
_get_current_time_seconds()
)
disaggregated_params["timing_metrics"] = timing_metrics

output: Dict[str, Any] = {
"token_ids": list(token_ids),
"disaggregated_params": (
{"kv_transfer_params": res.kv_transfer_params}
if res.kv_transfer_params
else None
),
"disaggregated_params": disaggregated_params if disaggregated_params else None,
"completion_usage": BaseWorkerHandler._build_completion_usage(
request_output=res
),
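For review context, here is a minimal, self-contained sketch of the merge the decode path's `generate` performs above when folding a prefill worker's `timing_metrics` into its own dict. The key names match this diff; the standalone helper and the timestamp values are illustrative only. The merged dict then rides out on the `finish_reason` chunk under `disaggregated_params["timing_metrics"]`, as the handler code shows.

```python
# Illustrative sketch of the timing merge in the decode handler above.
# Key names match the diff; this helper and the values are hypothetical.
from typing import Dict, Optional


def merge_prefill_timing(
    timing_metrics: Dict[str, float],
    prefill_timing: Optional[Dict[str, float]],
) -> Dict[str, float]:
    """Merge prefill timing but keep the frontend's request_received_seconds."""
    if prefill_timing:
        received = timing_metrics.get("request_received_seconds")
        timing_metrics.update(prefill_timing)
        if received is not None:
            timing_metrics["request_received_seconds"] = received
    return timing_metrics


merged = merge_prefill_timing(
    {"request_received_seconds": 1700000000.000},
    {
        # The prefill worker carries its own copy; the frontend's wins.
        "request_received_seconds": 1700000000.002,
        "prefill_start_seconds": 1700000000.010,
        "prefill_end_seconds": 1700000000.090,
    },
)
assert merged["request_received_seconds"] == 1700000000.000
```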
131 changes: 129 additions & 2 deletions lib/llm/src/http/service/openai.rs
@@ -56,6 +56,39 @@ pub const DYNAMO_REQUEST_ID_HEADER: &str = "x-dynamo-request-id";
/// Dynamo Annotation for the request ID
pub const ANNOTATION_REQUEST_ID: &str = "request_id";

/// Injects `request_completed_seconds` into the nvext timing_metrics field.
/// This captures the exact moment when the response is about to leave the server.
fn inject_request_completed_seconds(nvext: &mut Option<serde_json::Value>) {
let ts = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs_f64())
.ok();

if let Some(ts) = ts {
let nvext = nvext.get_or_insert_with(|| serde_json::Value::Object(serde_json::Map::new()));
if let Some(obj) = nvext.as_object_mut() {
if let Some(timing) = obj.get_mut("timing_metrics") {
if let Some(timing_obj) = timing.as_object_mut() {
timing_obj.insert(
"request_completed_seconds".to_string(),
serde_json::Value::from(ts),
);
}
} else {
let mut timing_obj = serde_json::Map::new();
timing_obj.insert(
"request_completed_seconds".to_string(),
serde_json::Value::from(ts),
);
obj.insert(
"timing_metrics".to_string(),
serde_json::Value::Object(timing_obj),
);
}
}
}
}

// Default axum max body limit without configuring is 2MB: https://docs.rs/axum/latest/axum/extract/struct.DefaultBodyLimit.html
/// Default body limit in bytes (45MB) to support 500k+ token payloads.
/// Can be configured at compile time using the DYN_FRONTEND_BODY_LIMIT_MB environment variable
@@ -281,8 +314,20 @@ fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> String
async fn handler_completions(
State(state): State<Arc<service_v2::State>>,
headers: HeaderMap,
-Json(request): Json<NvCreateCompletionRequest>,
+Json(mut request): Json<NvCreateCompletionRequest>,
) -> Result<Response, ErrorResponse> {
// Capture received timestamp immediately when request arrives at the frontend
let request_received_seconds = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs_f64())
.ok();

// Set request_received_seconds in nvext for timing metrics
if request_received_seconds.is_some() {
let nvext = request.nvext.get_or_insert_with(Default::default);
nvext.request_received_seconds = request_received_seconds;
}

// return a 503 if the service is not ready
check_ready(&state)?;

@@ -405,6 +450,23 @@ async fn completions_single(
// apply any annotations to the front of the stream
let stream = stream::iter(annotations).chain(stream);

// Inject request_completed_seconds into nvext when final content chunk is sent
// Only inject on the chunk with finish_reason (not the usage-only chunk)
// because timing_metrics is populated by the backend on the finish_reason chunk
let stream = stream.map(|mut annotated| {
if let Some(ref mut response) = annotated.data {
let has_finish_reason = response
.inner
.choices
.iter()
.any(|choice| choice.finish_reason.is_some());
if has_finish_reason {
inject_request_completed_seconds(&mut response.inner.nvext);
}
}
annotated
});

if streaming {
// For streaming, we'll drop the http_queue_guard on the first token
let mut http_queue_guard = Some(http_queue_guard);
@@ -686,8 +748,20 @@ async fn embeddings(
async fn handler_chat_completions(
State((state, template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>,
headers: HeaderMap,
-Json(request): Json<NvCreateChatCompletionRequest>,
+Json(mut request): Json<NvCreateChatCompletionRequest>,
) -> Result<Response, ErrorResponse> {
// Capture received timestamp immediately when request arrives at the frontend
let request_received_seconds = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs_f64())
.ok();

// Set request_received_seconds in nvext for timing metrics
if request_received_seconds.is_some() {
let nvext = request.nvext.get_or_insert_with(Default::default);
nvext.request_received_seconds = request_received_seconds;
}

// return a 503 if the service is not ready
check_ready(&state)?;

@@ -821,6 +895,22 @@ async fn chat_completions(
// todo - tap the stream and propagate request level metrics
// note - we might do this as part of the post processing set to make it more generic

// Inject request_completed_seconds into nvext when final content chunk is sent
// Only inject on the chunk with finish_reason (not the usage-only chunk)
// because timing_metrics is populated by the backend on the finish_reason chunk
let stream = stream.map(|mut annotated| {
if let Some(ref mut response) = annotated.data {
let has_finish_reason = response
.choices
.iter()
.any(|choice| choice.finish_reason.is_some());
if has_finish_reason {
inject_request_completed_seconds(&mut response.nvext);
}
}
annotated
});

if streaming {
stream_handle.arm(); // allows the system to detect client disconnects and cancel the LLM generation

@@ -2055,4 +2145,41 @@ mod tests {
assert!(msg.contains("response_format"));
}
}

// Tests for inject_request_completed_seconds function
#[test]
fn test_inject_request_completed_seconds_into_existing_timing_metrics() {
let mut nvext = Some(serde_json::json!({
"timing_metrics": {
"request_received_seconds": 1700000000.0,
"decode_end_seconds": 1700000001.5
}
}));

inject_request_completed_seconds(&mut nvext);

let nvext = nvext.unwrap();
let timing = nvext.get("timing_metrics").unwrap();
assert!(timing.get("request_completed_seconds").is_some());
// Verify existing fields are preserved
assert_eq!(
timing.get("request_received_seconds").unwrap().as_f64(),
Some(1700000000.0)
);
assert_eq!(
timing.get("decode_end_seconds").unwrap().as_f64(),
Some(1700000001.5)
);
}

#[test]
fn test_inject_request_completed_seconds_creates_nvext_if_none() {
let mut nvext: Option<serde_json::Value> = None;

inject_request_completed_seconds(&mut nvext);

let nvext = nvext.unwrap();
let timing = nvext.get("timing_metrics").unwrap();
assert!(timing.get("request_completed_seconds").is_some());
}
}
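With `request_completed_seconds` stamped by `inject_request_completed_seconds` on the `finish_reason` chunk, a fully populated `timing_metrics` carries seven timestamps. A hedged sketch of client-side post-processing follows, assuming all seven are present; the derived names (`time_to_first_token`, etc.) are illustrative, not part of the API.

```python
# Hypothetical post-processing of a fully populated timing_metrics dict.
# The seven input keys come from this PR; the derived durations are
# illustrative only.


def summarize_timing(timing: dict) -> dict:
    """Turn raw epoch timestamps into latency durations (seconds)."""
    return {
        "time_to_first_token": timing["decode_first_token_seconds"]
        - timing["request_received_seconds"],
        "prefill_duration": timing["prefill_end_seconds"]
        - timing["prefill_start_seconds"],
        "decode_duration": timing["decode_end_seconds"]
        - timing["decode_start_seconds"],
        "end_to_end": timing["request_completed_seconds"]
        - timing["request_received_seconds"],
    }


print(
    summarize_timing(
        {
            "request_received_seconds": 1700000000.000,
            "prefill_start_seconds": 1700000000.010,
            "prefill_end_seconds": 1700000000.090,
            "decode_start_seconds": 1700000000.091,
            "decode_first_token_seconds": 1700000000.120,
            "decode_end_seconds": 1700000000.900,
            "request_completed_seconds": 1700000000.905,
        }
    )
)
```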
3 changes: 2 additions & 1 deletion lib/llm/src/preprocessor.rs
@@ -237,10 +237,11 @@ impl OpenAIPreprocessor {
builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None);
-// Extract backend_instance_id and extra_fields from nvext if present
+// Extract backend_instance_id, extra_fields, and request_received_seconds from nvext if present
if let Some(nvext) = request.nvext() {
builder.backend_instance_id(nvext.backend_instance_id);
builder.extra_fields(nvext.extra_fields.clone());
builder.request_received_seconds(nvext.request_received_seconds);
}

Ok(builder)
6 changes: 6 additions & 0 deletions lib/llm/src/protocols/common/preprocessor.rs
@@ -101,6 +101,12 @@ pub struct PreprocessedRequest {
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_fields: Option<Vec<String>>,

/// Timestamp when the request was received by the frontend (seconds since epoch)
/// Used for timing metrics to track end-to-end latency
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub request_received_seconds: Option<f64>,
}

impl PreprocessedRequest {
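Putting the pieces together, opting in end to end looks roughly like the following. This is a sketch under stated assumptions: the endpoint, port, model name, and the exact location of `timing_metrics` in the response body are inferred from the fields this PR touches (`nvext.extra_fields` on the way in, `nvext.timing_metrics` on the way out), not a documented contract.

```python
# Hedged end-to-end sketch: opt in to timing metrics via
# nvext.extra_fields and read them back from the response. Endpoint,
# port, and model name are placeholders.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-model",  # placeholder
        "messages": [{"role": "user", "content": "Hello"}],
        "nvext": {"extra_fields": ["timing_metrics"]},
    },
    timeout=60,
)
timing = resp.json().get("nvext", {}).get("timing_metrics", {})
print(timing)
```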