107 changes: 101 additions & 6 deletions components/src/dynamo/vllm/handlers.py
@@ -5,9 +5,10 @@
import logging
import os
import tempfile
import time
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Final, List, Optional

from vllm.inputs import TokensPrompt
from vllm.outputs import RequestOutput
@@ -73,6 +74,14 @@ def build_sampling_params(
return sampling_params


def _should_include_timing_metrics(request: Dict[str, Any]) -> bool:
"""Check if timing_metrics is requested in extra_fields."""
extra_fields: Optional[List[str]] = request.get("extra_fields")
if extra_fields is None:
return False
return "timing_metrics" in extra_fields


class BaseWorkerHandler(ABC):
"""
Request handler for the generate and clear_kv_blocks endpoints.
@@ -253,7 +262,7 @@ async def generate_tokens(
out[
"completion_usage"
] = BaseWorkerHandler._build_completion_usage(
request_output=res,
)
if output.stop_reason:
out["stop_reason"] = output.stop_reason
@@ -296,6 +305,18 @@ async def generate(self, request, context):
request_id = context.id()
logger.debug(f"Decode Request ID: {request_id}")

# Check if timing metrics are requested
include_timing = _should_include_timing_metrics(request)

# Initialize timing metrics using request_received_seconds from frontend (passed via PreprocessedRequest)
timing_metrics: Optional[Dict[str, float]] = None
if include_timing:
timing_metrics = {}
# Use request_received_seconds from the request (set by frontend) if available
frontend_received = request.get("request_received_seconds")
if frontend_received is not None:
timing_metrics["request_received_seconds"] = frontend_received

# Extract and decode multimodal data if present
multi_modal_data = await self._extract_multimodal_data(request)

@@ -313,6 +334,17 @@ async def generate(self, request, context):
kv_params = prefill_result.get("disaggregated_params", {}).get(
"kv_transfer_params"
)
# Extract prefill timing from prefill_result if available
if include_timing:
prefill_timing = prefill_result.get("disaggregated_params", {}).get(
"timing_metrics"
)
if prefill_timing:
# Merge prefill timing but keep the frontend's request_received_seconds
received = timing_metrics.get("request_received_seconds")
timing_metrics.update(prefill_timing)
if received is not None:
timing_metrics["request_received_seconds"] = received
else:
kv_params = None

@@ -329,15 +361,51 @@

dp_rank = request.get("dp_rank", None)

# Track decode timing
first_token_sent = False

async with self._abort_monitor(context, request_id):
try:
# Record decode start time
if include_timing:
decode_start_seconds = time.time()
# If this is aggregated mode (no prefill_result), prefill_start == decode_start
if prefill_result is None:
timing_metrics["prefill_start_seconds"] = decode_start_seconds
timing_metrics["decode_start_seconds"] = decode_start_seconds

async for tok in self.generate_tokens(
prompt, sampling_params, request_id, data_parallel_rank=dp_rank
):
# Capture first token timing
if include_timing and not first_token_sent:
first_token_time = time.time()
timing_metrics["decode_first_token_seconds"] = first_token_time
# In aggregated mode, prefill finishes when first token is generated
if prefill_result is None:
timing_metrics["prefill_end_seconds"] = first_token_time
first_token_sent = True

if prefill_result is not None and "completion_usage" in tok:
tok["completion_usage"][
"prompt_tokens_details"
] = prefill_prompt_tokens_details

# On finish, record decode_end_seconds and inject timing_metrics
# Note: request_finish_seconds is set in the Rust HTTP layer when the response actually leaves the server
if tok.get("finish_reason") is not None and include_timing:
timing_metrics[
"decode_end_seconds"
] = time.time()

# Inject timing_metrics into disaggregated_params
if (
"disaggregated_params" not in tok
or tok["disaggregated_params"] is None
):
tok["disaggregated_params"] = {}
tok["disaggregated_params"]["timing_metrics"] = timing_metrics

yield tok
except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
@@ -370,6 +438,21 @@ async def generate(self, request, context):
request_id = context.id()
logger.debug(f"Prefill Request ID: {request_id}")

# Check if timing metrics are requested
include_timing = _should_include_timing_metrics(request)

# Initialize timing metrics using request_received_seconds from frontend (passed via PreprocessedRequest)
timing_metrics: Optional[Dict[str, float]] = None
if include_timing:
timing_metrics = {}
# Use request_received_seconds from the request (set by frontend) if available
frontend_received = request.get("request_received_seconds")
if frontend_received is not None:
timing_metrics["request_received_seconds"] = frontend_received

# Record prefill_start as when we start processing in the prefill worker
timing_metrics["prefill_start_seconds"] = time.time()

# Extract and decode multimodal data if present
multi_modal_data = await self._extract_multimodal_data(request)

@@ -422,15 +505,27 @@

token_ids = res.outputs[0].token_ids if res.outputs else []

# Build disaggregated_params with kv_transfer_params and timing_metrics
disaggregated_params: Optional[Dict[str, Any]] = {}

if res.kv_transfer_params:
disaggregated_params[
"kv_transfer_params"
] = res.kv_transfer_params

if include_timing and timing_metrics:
timing_metrics[
"prefill_end_seconds"
] = time.time()
disaggregated_params["timing_metrics"] = timing_metrics

output: Dict[str, Any] = {
"token_ids": list(token_ids),
"disaggregated_params": (
{"kv_transfer_params": res.kv_transfer_params}
if res.kv_transfer_params
else None
disaggregated_params if disaggregated_params else None
),
"completion_usage": BaseWorkerHandler._build_completion_usage(
request_output=res,
),
}

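The patch records raw timestamps only; deriving latencies is left to the consumer. A minimal consumer-side sketch (hypothetical helper, not part of this diff) using the field names introduced above; note that request_finish_seconds is stamped later by the Rust HTTP layer and may be absent at this level:

# Hypothetical helper, assuming the timing_metrics dict shape produced above.
from typing import Dict, Optional


def derive_latencies(timing: Dict[str, float]) -> Dict[str, Optional[float]]:
    def delta(start: str, end: str) -> Optional[float]:
        # Return None when either timestamp was not recorded.
        if start in timing and end in timing:
            return timing[end] - timing[start]
        return None

    return {
        "prefill_seconds": delta("prefill_start_seconds", "prefill_end_seconds"),
        "decode_seconds": delta("decode_start_seconds", "decode_end_seconds"),
        "time_to_first_token_seconds": delta(
            "request_received_seconds", "decode_first_token_seconds"
        ),
    }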
208 changes: 208 additions & 0 deletions components/src/dynamo/vllm/tests/test_vllm_extra_fields.py
@@ -0,0 +1,208 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Unit tests for extra_fields handling in vLLM handlers."""

import asyncio
import warnings
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

# Filter Pydantic deprecation warning before importing handlers
warnings.filterwarnings(
"ignore",
message=".*json_encoders.*is deprecated.*",
category=DeprecationWarning,
)

from dynamo.vllm.handlers import ( # noqa: E402
DecodeWorkerHandler,
PrefillWorkerHandler,
_should_include_timing_metrics,
)


pytestmark = [
pytest.mark.unit,
pytest.mark.vllm,
pytest.mark.gpu_1,
pytest.mark.pre_merge,
]


class TestShouldIncludeTimingMetrics:
"""Tests for _should_include_timing_metrics helper function."""

def test_returns_true_with_multiple_extra_fields(self):
"""Timing metrics should be included when explicitly requested."""
request = {"extra_fields": ["worker_id", "timing_metrics", "other_field"]}
assert _should_include_timing_metrics(request) is True

def test_returns_false_when_extra_fields_is_none(self):
"""Timing metrics should not be included when extra_fields is None."""
request = {"extra_fields": None}
assert _should_include_timing_metrics(request) is False

def test_returns_false_when_extra_fields_missing(self):
"""Timing metrics should not be included when extra_fields key is absent."""
request: dict[str, list[str]] = {}
assert _should_include_timing_metrics(request) is False


def make_mock_request_output(
token_ids: list[int],
finish_reason: str | None = None,
prompt_token_ids: list[int] | None = None,
):
"""Create a mock vLLM RequestOutput."""
output = MagicMock()
output.token_ids = token_ids
output.finish_reason = finish_reason
output.stop_reason = None

request_output = MagicMock()
request_output.outputs = [output]
request_output.prompt_token_ids = prompt_token_ids or [1, 2, 3]
request_output.num_cached_tokens = 0
request_output.kv_transfer_params = None
return request_output


def create_mock_handler(handler_class: type):
"""Create a handler with mocked dependencies."""
runtime = MagicMock()
component = MagicMock()
engine = MagicMock()
default_sampling_params: dict[str, str] = {}

with patch("dynamo.vllm.handlers.VllmEngineMonitor"):
with patch("dynamo.vllm.handlers.ImageLoader"):
handler = handler_class(
runtime=runtime,
component=component,
engine=engine,
default_sampling_params=default_sampling_params,
model_max_len=4096,
)
return handler


def create_mock_context(request_id: str = "test-request-123"):
"""Create a mock context that doesn't trigger abort."""
context = MagicMock()
context.id.return_value = request_id
# Make async_killed_or_stopped hang forever (never abort)
context.async_killed_or_stopped = AsyncMock(side_effect=asyncio.CancelledError)
return context


class TestDecodeWorkerHandlerTiming:
"""E2E tests for timing metrics in DecodeWorkerHandler."""

@pytest.mark.asyncio
async def test_no_timing_metrics_when_not_requested(self):
"""When timing_metrics not requested, no timing data in output."""
handler = create_mock_handler(DecodeWorkerHandler)
context = create_mock_context()

final_output = make_mock_request_output([100], finish_reason="stop")

async def mock_generate(*args, **kwargs):
yield final_output

handler.engine_client.generate = mock_generate

request = {
"token_ids": [1, 2, 3],
"sampling_options": {},
"stop_conditions": {},
}

results = []
async for output in handler.generate(request, context):
results.append(output)

final = results[-1]
assert (
final.get("disaggregated_params") is None
or final.get("disaggregated_params", {}).get("timing_metrics") is None
)

@pytest.mark.asyncio
async def test_disaggregated_mode_preserves_frontend_timestamp(self):
"""In disaggregated mode, frontend's request_received_seconds is preserved."""
handler = create_mock_handler(DecodeWorkerHandler)
context = create_mock_context()

final_output = make_mock_request_output([100], finish_reason="stop")

async def mock_generate(*args, **kwargs):
yield final_output

handler.engine_client.generate = mock_generate

request = {
"token_ids": [1, 2, 3],
"sampling_options": {},
"stop_conditions": {},
"extra_fields": ["timing_metrics"],
"request_received_seconds": 1000.0,
"prefill_result": {
"disaggregated_params": {
"timing_metrics": {
"request_received_seconds": 999.0,
"prefill_start_seconds": 1001.0,
"prefill_end_seconds": 1002.0,
}
}
},
}

results = []
async for output in handler.generate(request, context):
results.append(output)

timing = results[-1]["disaggregated_params"]["timing_metrics"]

# Frontend's timestamp must be preserved
assert timing["request_received_seconds"] == 1000.0
# Prefill timing should be merged
assert timing["prefill_start_seconds"] == 1001.0
assert timing["prefill_end_seconds"] == 1002.0


class TestPrefillWorkerHandlerTiming:
"""E2E tests for timing metrics in PrefillWorkerHandler."""

@pytest.mark.asyncio
async def test_timing_metrics_included_in_prefill_output(self):
"""When timing_metrics requested, prefill output contains timing data."""
handler = create_mock_handler(PrefillWorkerHandler)
context = create_mock_context()

prefill_output = make_mock_request_output([100])
prefill_output.kv_transfer_params = {"some": "params"}

async def mock_generate(*args, **kwargs):
yield prefill_output

handler.engine_client.generate = mock_generate

request = {
"token_ids": [1, 2, 3],
"sampling_options": {},
"stop_conditions": {},
"extra_fields": ["timing_metrics"],
"request_received_seconds": 1000.0,
}

results = []
async for output in handler.generate(request, context):
results.append(output)

timing = results[-1]["disaggregated_params"]["timing_metrics"]

assert timing["request_received_seconds"] == 1000.0
assert "prefill_start_seconds" in timing
assert "prefill_end_seconds" in timing