Skip to content
Prev Previous commit
Next Next commit
Handle chat templates that expect message[content] be a string
Signed-off-by: Krishnan Prashanth <[email protected]>
  • Loading branch information
KrishnanPrash committed Nov 17, 2025
commit d5eaec9c00087f9512dca46a040ebc582bc7c5a5
1 change: 1 addition & 0 deletions lib/llm/src/preprocessor/prompt/template.rs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be an indepdendent fix targeting main instead of keivenchang/MDC-fix-on-main-nvbugs5662072? Or does it need to target Keiven's branch?

Copy link
Contributor Author

@KrishnanPrash KrishnanPrash Nov 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we merge Keiven's PR (without this one), special models (with custom block tags) would still be usable with dynamo+vLLM, but could fail if their inference request's are malformed (msg[content] is a string, but the model chat template wants a list). I will leave it up to you on what branch this fix should target. I guess it's more just a question of model support.

Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ struct HfTokenizerConfigJsonFormatter {
config: ChatTemplate,
mixins: Arc<ContextMixins>,
supports_add_generation_prompt: bool,
requires_content_arrays: bool,
}

// /// OpenAI Standard Prompt Formatter
Expand Down
47 changes: 47 additions & 0 deletions lib/llm/src/preprocessor/prompt/template/formatters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,44 @@ fn replace_non_standard_blocks(template: &str) -> String {
result
}

/// Detects whether a chat template requires message content as arrays (multimodal)
/// or accepts simple strings (standard text-only templates).
///
/// This function test-renders the template with both formats:
/// - Array format: `[{"type": "text", "text": "X"}]`
/// - String format: `"X"`
///
/// If the array format works but string format doesn't produce output,
/// the template requires arrays (e.g., llava, Qwen-VL multimodal templates).
fn detect_content_array_usage(env: &Environment) -> bool {
use minijinja::context;
use serde_json::json;

// Test with array format
let test_array = context! {
messages => json!([{"role": "user", "content": [{"type": "text", "text": "X"}]}]),
add_generation_prompt => false,
};

// Test with string format
let test_string = context! {
messages => json!([{"role": "user", "content": "X"}]),
add_generation_prompt => false,
};

let out_array = env
.get_template("default")
.and_then(|t| t.render(&test_array))
.unwrap_or_default();
let out_string = env
.get_template("default")
.and_then(|t| t.render(&test_string))
.unwrap_or_default();

// If array works but string doesn't, template requires arrays
out_array.contains("X") && !out_string.contains("X")
}

impl JinjaEnvironment {
fn env(self) -> Environment<'static> {
self.env
Expand Down Expand Up @@ -165,11 +203,20 @@ impl HfTokenizerConfigJsonFormatter {
}
}

// Detect at model load time whether this template requires content arrays
let requires_content_arrays = detect_content_array_usage(&env);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does vllm-serve, sglang, trtllm-serve, etc. solve this problem?

Copy link
Contributor Author

@KrishnanPrash KrishnanPrash Nov 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From previous investigations, vLLM's pre-processor does something like:

  • Feeds the chat template to Jinja2, which internally parses and generates an AST.
  • Relies on internal representations of Jinja2 to check if a for loop is done over messages[content].
  • If yes, leave messages[content] as an array.
  • If no, flatten messages[content] to a string. (Similar to what we do in feat: Convert message[content] from list to string. #3067)

And from my limited investigation into MiniJinja it does not expose anything similar to that.


tracing::info!(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

info is too noisy for this, maybe debug or trace instead if we need this log

"Template analysis: requires_content_arrays = {}",
requires_content_arrays
);

Ok(HfTokenizerConfigJsonFormatter {
env,
config,
mixins: Arc::new(mixins),
supports_add_generation_prompt: supports_add_generation_prompt.unwrap_or(false),
requires_content_arrays,
})
}
}
Expand Down
148 changes: 120 additions & 28 deletions lib/llm/src/preprocessor/prompt/template/oai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,11 @@ fn may_be_fix_tool_schema(tools: serde_json::Value) -> Option<Value> {
Some(Value::from_serialize(&updated_tools))
}

fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
// If messages[content] is provided as a list containing ONLY text parts,
// concatenate them into a string to match chat template expectations.
// Mixed content types are left for chat templates to handle.
fn may_be_fix_msg_content(messages: serde_json::Value, preserve_arrays: bool) -> Value {
// Bidirectional normalization for message content format:
// - If preserve_arrays=true (multimodal templates): Convert strings → arrays
// - If preserve_arrays=false (standard templates): Flatten text-only arrays → strings
// - Mixed content types are always preserved as-is

let Some(arr) = messages.as_array() else {
return Value::from_serialize(&messages);
Expand All @@ -86,7 +87,20 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
.iter()
.map(|msg| {
match msg.get("content") {
Some(serde_json::Value::Array(content_array)) => {
// Case 1: String → Array (for multimodal templates)
Some(serde_json::Value::String(text)) if preserve_arrays => {
let mut modified_msg = msg.clone();
if let Some(msg_object) = modified_msg.as_object_mut() {
let content_array = serde_json::json!([{
"type": "text",
"text": text
}]);
msg_object.insert("content".to_string(), content_array);
}
modified_msg
}
// Case 2: Array → String (for standard templates)
Some(serde_json::Value::Array(content_array)) if !preserve_arrays => {
let is_text_only_array = !content_array.is_empty()
&& content_array.iter().all(|part| {
part.get("type")
Expand Down Expand Up @@ -159,19 +173,7 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {

fn messages(&self) -> Value {
let messages_json = serde_json::to_value(&self.inner.messages).unwrap();

let needs_fixing = if let Some(arr) = messages_json.as_array() {
arr.iter()
.any(|msg| msg.get("content").and_then(|c| c.as_array()).is_some())
} else {
false
};

if needs_fixing {
may_be_fix_msg_content(messages_json)
} else {
Value::from_serialize(&messages_json)
}
Value::from_serialize(&messages_json)
}

fn tools(&self) -> Option<Value> {
Expand Down Expand Up @@ -301,6 +303,15 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
let messages_canonical = req.messages();
let mut messages_for_template: serde_json::Value =
serde_json::to_value(&messages_canonical).unwrap();

// Apply bidirectional content normalization based on template requirements
let preserve_arrays = self.requires_content_arrays;
messages_for_template = serde_json::to_value(may_be_fix_msg_content(
messages_for_template,
preserve_arrays,
))
.unwrap();

normalize_tool_arguments_in_messages(&mut messages_for_template);

let ctx = context! {
Expand Down Expand Up @@ -457,7 +468,10 @@ mod tests {
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();

// Test array → string normalization (preserve_arrays=false for standard templates)
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();

// Verify: text-only array is concatenated into a single string
assert_eq!(
Expand Down Expand Up @@ -500,7 +514,10 @@ mod tests {
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();

// Test array → string normalization (preserve_arrays=false for standard templates)
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();

// Verify: System message with string content remains unchanged
assert_eq!(
Expand Down Expand Up @@ -541,7 +558,10 @@ mod tests {
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();

// Empty arrays should be preserved regardless of preserve_arrays setting
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();

// Verify: Empty arrays are preserved as-is
assert!(messages[0]["content"].is_array());
Expand All @@ -562,7 +582,10 @@ mod tests {
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();

// Test with preserve_arrays=false (standard templates)
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();

// Verify: String content is not modified
assert_eq!(
Expand All @@ -589,7 +612,10 @@ mod tests {
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();

// Mixed content should be preserved regardless of preserve_arrays setting
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();

// Verify: Mixed content types are preserved as array for template handling
assert!(messages[0]["content"].is_array());
Expand Down Expand Up @@ -617,7 +643,10 @@ mod tests {
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();

// Non-text arrays should be preserved regardless of preserve_arrays setting
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();

// Verify: Non-text content arrays are preserved for template handling
assert!(messages[0]["content"].is_array());
Expand Down Expand Up @@ -713,7 +742,8 @@ NORMAL MODE
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();

// Mixed types should preserve array structure
assert!(messages[0]["content"].is_array());
Expand All @@ -735,7 +765,8 @@ NORMAL MODE
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();

// Unknown types mixed with text should preserve array
assert!(messages[0]["content"].is_array());
Expand Down Expand Up @@ -873,11 +904,15 @@ NORMAL MODE
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let mut messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();

// Apply content normalization with preserve_arrays=false (standard templates)
let mut messages =
serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();

normalize_tool_arguments_in_messages(&mut messages);

// Multimodal content preserved as array
// Multimodal content preserved as array (mixed types not flattened)
assert!(messages[0]["content"].is_array());
assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);

Expand All @@ -889,6 +924,63 @@ NORMAL MODE
);
}

/// Tests string → array normalization for multimodal templates
#[test]
fn test_may_be_fix_msg_content_string_to_array() {
let json_str = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": "Hello, how are you?"
}
]
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();

// Test with preserve_arrays=true (multimodal templates)
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, true)).unwrap();

// Verify: String is converted to array format
assert!(messages[0]["content"].is_array());
let content_array = messages[0]["content"].as_array().unwrap();
assert_eq!(content_array.len(), 1);
assert_eq!(content_array[0]["type"], "text");
assert_eq!(content_array[0]["text"], "Hello, how are you?");
}

/// Tests that arrays are preserved when preserve_arrays=true
#[test]
fn test_may_be_fix_msg_content_array_preserved_with_multimodal() {
let json_str = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "part 1"},
{"type": "text", "text": "part 2"}
]
}
]
}"#;

let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();

// Test with preserve_arrays=true (multimodal templates)
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, true)).unwrap();

// Verify: Array is preserved as-is
assert!(messages[0]["content"].is_array());
let content_array = messages[0]["content"].as_array().unwrap();
assert_eq!(content_array.len(), 2);
assert_eq!(content_array[0]["text"], "part 1");
assert_eq!(content_array[1]["text"], "part 2");
}

fn user() -> Msg {
Msg::User(Default::default())
}
Expand Down
Loading