2 changes: 2 additions & 0 deletions examples/backends/vllm/launch/agg_multimodal.sh
@@ -46,6 +46,8 @@ python -m dynamo.frontend --http-port=8000 &
EXTRA_ARGS=""
if [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048"
elif [[ "$MODEL_NAME" == "llava-hf/llava-1.5-7b-hf" ]]; then
EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048"
fi

# Start vLLM worker with vision model
1 change: 1 addition & 0 deletions lib/llm/src/preprocessor/prompt/template.rs
@@ -106,6 +106,7 @@ struct HfTokenizerConfigJsonFormatter {
config: ChatTemplate,
mixins: Arc<ContextMixins>,
supports_add_generation_prompt: bool,
requires_content_arrays: bool,
}

// /// OpenAI Standard Prompt Formatter
35 changes: 34 additions & 1 deletion lib/llm/src/preprocessor/prompt/template/formatters.rs
@@ -6,9 +6,38 @@ use std::sync::Arc;
use super::tokcfg::{ChatTemplate, raise_exception, strftime_now, tojson};
use super::{ContextMixins, HfTokenizerConfigJsonFormatter, JinjaEnvironment};
use either::Either;
use minijinja::{Environment, Value};
use minijinja::{Environment, Value, context};
use serde_json::json;
use tracing;

/// Detects if a template requires content as arrays (multimodal) vs strings (text-only).
/// Returns true if the template only works with array format.
fn detect_content_array_usage(env: &Environment) -> bool {
// Test with array format
let array_msg = context! {
messages => json!([{"role": "user", "content": [{"type": "text", "text": "template_test"}]}]),
add_generation_prompt => false,
};

// Test with string format
let string_msg = context! {
messages => json!([{"role": "user", "content": "template_test"}]),
add_generation_prompt => false,
};

let out_array = env
.get_template("default")
.and_then(|t| t.render(&array_msg))
.unwrap_or_default();
let out_string = env
.get_template("default")
.and_then(|t| t.render(&string_msg))
.unwrap_or_default();

// If array works but string doesn't, template requires arrays
out_array.contains("template_test") && !out_string.contains("template_test")
}

/// Remove known non-standard Jinja2 tags from chat templates
///
/// Some models use custom Jinja2 extensions that minijinja doesn't recognize. These tags
@@ -120,11 +149,15 @@ impl HfTokenizerConfigJsonFormatter {
}
}

// Detect at model load time whether this template requires content arrays
let requires_content_arrays = detect_content_array_usage(&env);

Ok(HfTokenizerConfigJsonFormatter {
env,
config,
mixins: Arc::new(mixins),
supports_add_generation_prompt: supports_add_generation_prompt.unwrap_or(false),
requires_content_arrays,
})
}
}
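To make the load-time probe concrete, here is a minimal standalone sketch (not part of this diff) of the idea behind detect_content_array_usage. The two template strings are hypothetical stand-ins for real chat templates, and the snippet assumes minijinja and serde_json as dependencies:

// Hypothetical templates illustrating the probe; not the PR's code.
use minijinja::{Environment, context};
use serde_json::json;

fn main() {
    // Multimodal-style template: iterates content parts, so plain string
    // content yields no usable text.
    let multimodal =
        "{% for m in messages %}{% for p in m.content %}{{ p.text }}{% endfor %}{% endfor %}";
    // Standard text template: prints content directly, so both shapes render.
    let standard = "{% for m in messages %}{{ m.content }}{% endfor %}";

    for (name, src) in [("multimodal", multimodal), ("standard", standard)] {
        let mut env = Environment::new();
        env.add_template("default", src).unwrap();
        let tmpl = env.get_template("default").unwrap();

        let array_ok = tmpl
            .render(context! {
                messages => json!([{"role": "user", "content": [{"type": "text", "text": "probe"}]}]),
            })
            .unwrap_or_default()
            .contains("probe");
        let string_ok = tmpl
            .render(context! {
                messages => json!([{"role": "user", "content": "probe"}]),
            })
            .unwrap_or_default()
            .contains("probe");

        // requires_content_arrays is true only when the array form renders the
        // probe text and the string form does not.
        println!("{name}: requires_content_arrays = {}", array_ok && !string_ok);
    }
}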
146 changes: 117 additions & 29 deletions lib/llm/src/preprocessor/prompt/template/oai.rs
@@ -73,10 +73,9 @@ fn may_be_fix_tool_schema(tools: serde_json::Value) -> Option<Value> {
Some(Value::from_serialize(&updated_tools))
}

fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
// If messages[content] is provided as a list containing ONLY text parts,
// concatenate them into a string to match chat template expectations.
// Mixed content types are left for chat templates to handle.
fn may_be_fix_msg_content(messages: serde_json::Value, preserve_arrays: bool) -> Value {
// preserve_arrays=true: strings → arrays (multimodal)
// preserve_arrays=false: text-only arrays → strings (standard)

let Some(arr) = messages.as_array() else {
return Value::from_serialize(&messages);
@@ -86,7 +85,20 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
.iter()
.map(|msg| {
match msg.get("content") {
Some(serde_json::Value::Array(content_array)) => {
// Case 1: String to Array (for multimodal templates)
Some(serde_json::Value::String(text)) if preserve_arrays => {
let mut modified_msg = msg.clone();
if let Some(msg_object) = modified_msg.as_object_mut() {
let content_array = serde_json::json!([{
"type": "text",
"text": text
}]);
msg_object.insert("content".to_string(), content_array);
}
modified_msg
}
// Case 2: Array to String (for standard templates)
Some(serde_json::Value::Array(content_array)) if !preserve_arrays => {
let is_text_only_array = !content_array.is_empty()
&& content_array.iter().all(|part| {
part.get("type")
@@ -114,7 +126,7 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
msg.clone() // Mixed content or non-text only
}
}
_ => msg.clone(), // String content or missing content - return unchanged
_ => msg.clone(), // No conversion needed
}
})
.collect();
@@ -159,19 +171,7 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {

fn messages(&self) -> Value {
let messages_json = serde_json::to_value(&self.inner.messages).unwrap();

let needs_fixing = if let Some(arr) = messages_json.as_array() {
arr.iter()
.any(|msg| msg.get("content").and_then(|c| c.as_array()).is_some())
} else {
false
};

if needs_fixing {
may_be_fix_msg_content(messages_json)
} else {
Value::from_serialize(&messages_json)
}
Value::from_serialize(&messages_json)
}

fn tools(&self) -> Option<Value> {
@@ -301,6 +301,13 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
let messages_canonical = req.messages();
let mut messages_for_template: serde_json::Value =
serde_json::to_value(&messages_canonical).unwrap();

messages_for_template = serde_json::to_value(may_be_fix_msg_content(
messages_for_template,
self.requires_content_arrays,
))
.unwrap();

normalize_tool_arguments_in_messages(&mut messages_for_template);

let ctx = context! {
@@ -457,7 +464,10 @@ mod tests {
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();

// Test array → string normalization (preserve_arrays=false for standard templates)
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();

// Verify: text-only array is concatenated into a single string
assert_eq!(
@@ -500,7 +510,10 @@
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();

// Test array → string normalization (preserve_arrays=false for standard templates)
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();

// Verify: System message with string content remains unchanged
assert_eq!(
@@ -541,7 +554,10 @@
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();

// Empty arrays should be preserved regardless of preserve_arrays setting
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();

// Verify: Empty arrays are preserved as-is
assert!(messages[0]["content"].is_array());
@@ -562,7 +578,10 @@
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();

// Test with preserve_arrays=false (standard templates)
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();

// Verify: String content is not modified
assert_eq!(
@@ -589,7 +608,10 @@
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();

// Mixed content should be preserved regardless of preserve_arrays setting
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();

// Verify: Mixed content types are preserved as array for template handling
assert!(messages[0]["content"].is_array());
@@ -617,7 +639,10 @@
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();

// Non-text arrays should be preserved regardless of preserve_arrays setting
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();

// Verify: Non-text content arrays are preserved for template handling
assert!(messages[0]["content"].is_array());
@@ -713,7 +738,8 @@ NORMAL MODE
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();

// Mixed types should preserve array structure
assert!(messages[0]["content"].is_array());
@@ -735,7 +761,8 @@
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();

// Unknown types mixed with text should preserve array
assert!(messages[0]["content"].is_array());
@@ -873,11 +900,15 @@ NORMAL MODE
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let mut messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();

// Apply content normalization with preserve_arrays=false (standard templates)
let mut messages =
serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();

normalize_tool_arguments_in_messages(&mut messages);

// Multimodal content preserved as array
// Multimodal content preserved as array (mixed types not flattened)
assert!(messages[0]["content"].is_array());
assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);

@@ -889,6 +920,63 @@
);
}

/// Tests string → array normalization for multimodal templates
#[test]
fn test_may_be_fix_msg_content_string_to_array() {
let json_str = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": "Hello, how are you?"
}
]
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();

// Test with preserve_arrays=true (multimodal templates)
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, true)).unwrap();

// Verify: String is converted to array format
assert!(messages[0]["content"].is_array());
let content_array = messages[0]["content"].as_array().unwrap();
assert_eq!(content_array.len(), 1);
assert_eq!(content_array[0]["type"], "text");
assert_eq!(content_array[0]["text"], "Hello, how are you?");
}

/// Tests that arrays are preserved when preserve_arrays=true
#[test]
fn test_may_be_fix_msg_content_array_preserved_with_multimodal() {
let json_str = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "part 1"},
{"type": "text", "text": "part 2"}
]
}
]
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();

// Test with preserve_arrays=true (multimodal templates)
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, true)).unwrap();

// Verify: Array is preserved as-is
assert!(messages[0]["content"].is_array());
let content_array = messages[0]["content"].as_array().unwrap();
assert_eq!(content_array.len(), 2);
assert_eq!(content_array[0]["text"], "part 1");
assert_eq!(content_array[1]["text"], "part 2");
}

fn user() -> Msg {
Msg::User(Default::default())
}
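For reference, a self-contained sketch of just the string → array branch added above. The helper name is hypothetical (the real logic lives inside may_be_fix_msg_content), but the wrapped shape matches the diff and its tests; the opposite array → string join is partly elided in this diff, so it is not reproduced here:

// Hypothetical standalone helper mirroring the string → array wrapping.
use serde_json::{json, Value};

fn wrap_string_content(msg: &Value) -> Value {
    let mut out = msg.clone();
    // Only string content is rewritten; everything else passes through.
    let Some(text) = out.get("content").and_then(Value::as_str).map(str::to_owned) else {
        return out;
    };
    if let Some(obj) = out.as_object_mut() {
        obj.insert("content".to_string(), json!([{ "type": "text", "text": text }]));
    }
    out
}

fn main() {
    let msg = json!({ "role": "user", "content": "Hello, how are you?" });
    let fixed = wrap_string_content(&msg);
    assert_eq!(fixed["content"][0]["type"], "text");
    assert_eq!(fixed["content"][0]["text"], "Hello, how are you?");
    println!("{fixed}");
}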
36 changes: 36 additions & 0 deletions tests/serve/test_vllm.py
@@ -217,6 +217,42 @@ class VLLMConfig(EngineConfig):
),
],
),
"multimodal_agg_llava": VLLMConfig(
name="multimodal_agg_llava",
directory=vllm_dir,
script_name="agg_multimodal.sh",
marks=[
pytest.mark.gpu_2,
# https://github.com/ai-dynamo/dynamo/issues/4501
pytest.mark.xfail(strict=False),
],
model="llava-hf/llava-1.5-7b-hf",
script_args=["--model", "llava-hf/llava-1.5-7b-hf"],
delayed_start=0,
timeout=360,
request_payloads=[
# HTTP URL test
chat_payload(
[
{"type": "text", "text": "What is in this image?"},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
},
},
],
repeat_count=1,
expected_response=["bus"],
temperature=0.0,
),
# String content test - verifies string → array conversion for multimodal templates
chat_payload_default(
repeat_count=1,
expected_response=[], # Just validate no error
),
],
),
# TODO: Update this test case when we have video multimodal support in vllm official components
"multimodal_video_agg": VLLMConfig(
name="multimodal_video_agg",