Skip to content
Prev Previous commit
Next Next commit
Fixes
Signed-off-by: Zhongxuan Wang <[email protected]>
  • Loading branch information
zhongxuanwang-nv committed Dec 1, 2025
commit ded8aa8c48a10f7a14fdd22685c73cabb3e81c44
53 changes: 32 additions & 21 deletions lib/llm/src/http/service/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,32 +58,25 @@ pub const ANNOTATION_REQUEST_ID: &str = "request_id";

/// Injects `request_completed_seconds` into the nvext timing_metrics field.
/// This captures the exact moment when the response is about to leave the server.
/// Only injects if timing_metrics already exists (i.e., the user requested it via extra_fields).
fn inject_request_completed_seconds(nvext: &mut Option<serde_json::Value>) {
let ts = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs_f64())
.ok();

if let Some(ts) = ts {
let nvext = nvext.get_or_insert_with(|| serde_json::Value::Object(serde_json::Map::new()));
if let Some(obj) = nvext.as_object_mut() {
if let Some(timing) = obj.get_mut("timing_metrics") {
if let Some(timing_obj) = timing.as_object_mut() {
timing_obj.insert(
"request_completed_seconds".to_string(),
serde_json::Value::from(ts),
);
// Only inject if nvext and timing_metrics already exist (user requested timing_metrics)
if let Some(nvext) = nvext.as_mut() {
if let Some(obj) = nvext.as_object_mut() {
if let Some(timing) = obj.get_mut("timing_metrics") {
if let Some(timing_obj) = timing.as_object_mut() {
timing_obj.insert(
"request_completed_seconds".to_string(),
serde_json::Value::from(ts),
);
}
}
} else {
let mut timing_obj = serde_json::Map::new();
timing_obj.insert(
"request_completed_seconds".to_string(),
serde_json::Value::from(ts),
);
obj.insert(
"timing_metrics".to_string(),
serde_json::Value::Object(timing_obj),
);
}
}
}
Expand Down Expand Up @@ -2173,13 +2166,31 @@ mod tests {
}

#[test]
fn test_inject_request_completed_seconds_does_not_create_nvext_if_none() {
    // If nvext is None (user didn't request timing_metrics), we should NOT create it
    let mut nvext: Option<serde_json::Value> = None;

    inject_request_completed_seconds(&mut nvext);

    assert!(
        nvext.is_none(),
        "nvext should remain None when timing_metrics was not requested"
    );
}

#[test]
fn test_inject_request_completed_seconds_does_not_create_timing_metrics_if_missing() {
    // If nvext exists but timing_metrics is not present, we should NOT create it
    let mut nvext = Some(serde_json::json!({
        "worker_id": {"prefill_worker_id": 42}
    }));

    inject_request_completed_seconds(&mut nvext);

    let nvext = nvext.expect("nvext was Some before the call and must stay Some");
    assert!(
        nvext.get("timing_metrics").is_none(),
        "timing_metrics should not be created if not already present"
    );
    // The fields that were already present must be passed through untouched.
    assert_eq!(
        nvext.get("worker_id"),
        Some(&serde_json::json!({"prefill_worker_id": 42})),
        "existing nvext fields should be left unchanged"
    );
}
}
29 changes: 23 additions & 6 deletions lib/llm/src/protocols/openai/chat_completions/delta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ impl NvCreateChatCompletionRequest {
.unwrap_or(false),
enable_logprobs: self.inner.logprobs.unwrap_or(false)
|| self.inner.top_logprobs.unwrap_or(0) > 0,
extra_fields: self.nvext.as_ref().and_then(|nv| nv.extra_fields.clone()),
runtime_config: ModelRuntimeConfig::default(),
};

Expand All @@ -64,6 +65,8 @@ pub struct DeltaGeneratorOptions {
pub enable_usage: bool,
/// Determines whether log probabilities should be included in the response.
pub enable_logprobs: bool,
/// Extra fields to include in response nvext (e.g., "worker_id", "timing_metrics")
pub extra_fields: Option<Vec<String>>,

pub runtime_config: ModelRuntimeConfig,
}
Expand Down Expand Up @@ -288,6 +291,15 @@ impl DeltaGenerator {
pub fn is_usage_enabled(&self) -> bool {
self.options.enable_usage
}

/// Reports whether `field` is listed in the `extra_fields` stored in this generator's options.
///
/// Returns `false` when no `extra_fields` list was supplied at all.
fn is_extra_field_requested(&self, field: &str) -> bool {
    match self.options.extra_fields.as_ref() {
        Some(requested) => requested.iter().any(|name| name == field),
        None => false,
    }
}
}

/// Implements the [`crate::protocols::openai::DeltaGeneratorExt`] trait for [`DeltaGenerator`], allowing
Expand Down Expand Up @@ -363,17 +375,22 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
let mut stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);

// Extract worker_id and timing_metrics from disaggregated_params and inject into nvext
// Only include fields that were explicitly requested via extra_fields
if let Some(ref disaggregated_params) = delta.disaggregated_params {
let mut nvext_obj = serde_json::Map::new();

// Extract worker_id if present
if let Some(worker_id_json) = disaggregated_params.get("worker_id") {
nvext_obj.insert("worker_id".to_string(), worker_id_json.clone());
// Extract worker_id if present and requested
if self.is_extra_field_requested("worker_id") {
if let Some(worker_id_json) = disaggregated_params.get("worker_id") {
nvext_obj.insert("worker_id".to_string(), worker_id_json.clone());
}
}

// Extract timing_metrics if present (pass through as-is)
if let Some(timing_metrics_json) = disaggregated_params.get("timing_metrics") {
nvext_obj.insert("timing_metrics".to_string(), timing_metrics_json.clone());
// Extract timing_metrics if present and requested
if self.is_extra_field_requested("timing_metrics") {
if let Some(timing_metrics_json) = disaggregated_params.get("timing_metrics") {
nvext_obj.insert("timing_metrics".to_string(), timing_metrics_json.clone());
}
}

// Only set nvext if we have at least one field
Expand Down
29 changes: 23 additions & 6 deletions lib/llm/src/protocols/openai/completions/delta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ impl NvCreateCompletionRequest {
.map(|opts| opts.include_usage)
.unwrap_or(false),
enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
extra_fields: self.nvext.as_ref().and_then(|nv| nv.extra_fields.clone()),
};

DeltaGenerator::new(self.inner.model.clone(), options, request_id)
Expand All @@ -51,6 +52,8 @@ impl NvCreateCompletionRequest {
pub struct DeltaGeneratorOptions {
/// Whether usage reporting is enabled; this is the flag returned by
/// `DeltaGenerator::is_usage_enabled`.
/// NOTE(review): presumably populated from the request's `include_usage` stream option —
/// confirm against the constructor.
pub enable_usage: bool,
/// Whether log probabilities are enabled; set when the request's `logprobs` value is
/// greater than zero.
pub enable_logprobs: bool,
/// Extra fields to include in response nvext (e.g., "worker_id", "timing_metrics")
pub extra_fields: Option<Vec<String>>,
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -222,6 +225,15 @@ impl DeltaGenerator {
pub fn is_usage_enabled(&self) -> bool {
self.options.enable_usage
}

/// Reports whether `field` is listed in the `extra_fields` stored in this generator's options.
///
/// Returns `false` when no `extra_fields` list was supplied at all.
fn is_extra_field_requested(&self, field: &str) -> bool {
    match self.options.extra_fields.as_ref() {
        Some(requested) => requested.iter().any(|name| name == field),
        None => false,
    }
}
}

impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for DeltaGenerator {
Expand Down Expand Up @@ -266,17 +278,22 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);

// Extract worker_id and timing_metrics from disaggregated_params and inject into nvext
// Only include fields that were explicitly requested via extra_fields
if let Some(ref disaggregated_params) = delta.disaggregated_params {
let mut nvext_obj = serde_json::Map::new();

// Extract worker_id if present
if let Some(worker_id_json) = disaggregated_params.get("worker_id") {
nvext_obj.insert("worker_id".to_string(), worker_id_json.clone());
// Extract worker_id if present and requested
if self.is_extra_field_requested("worker_id") {
if let Some(worker_id_json) = disaggregated_params.get("worker_id") {
nvext_obj.insert("worker_id".to_string(), worker_id_json.clone());
}
}

// Extract timing_metrics if present (pass through as-is)
if let Some(timing_metrics_json) = disaggregated_params.get("timing_metrics") {
nvext_obj.insert("timing_metrics".to_string(), timing_metrics_json.clone());
// Extract timing_metrics if present and requested
if self.is_extra_field_requested("timing_metrics") {
if let Some(timing_metrics_json) = disaggregated_params.get("timing_metrics") {
nvext_obj.insert("timing_metrics".to_string(), timing_metrics_json.clone());
}
}

// Only set nvext if we have at least one field
Expand Down
20 changes: 0 additions & 20 deletions lib/llm/src/protocols/openai/nvext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,6 @@ pub trait NvExtProvider {
fn raw_prompt(&self) -> Option<String>;
}

/// Worker ID information for disaggregated serving
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct WorkerIdInfo {
/// The prefill worker ID that processed this request
#[serde(skip_serializing_if = "Option::is_none")]
pub prefill_worker_id: Option<u64>,

/// The decode worker ID that processed this request
#[serde(skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>,
}

/// NVIDIA LLM response extensions
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct NvExtResponse {
/// Worker ID information (prefill and decode worker IDs)
#[serde(skip_serializing_if = "Option::is_none")]
pub worker_id: Option<WorkerIdInfo>,
}

/// NVIDIA LLM extensions to the OpenAI API
#[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone)]
#[validate(schema(function = "validate_nv_ext"))]
Expand Down
Loading