Skip to content
301 changes: 300 additions & 1 deletion lib/llm/src/preprocessor/prompt/template/oai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,75 @@ fn may_be_fix_tool_schema(tools: serde_json::Value) -> Option<Value> {
Some(Value::from_serialize(&updated_tools))
}

/// Flattens text-only multi-part message content into plain strings.
///
/// For every message whose `content` is a non-empty array in which each part
/// has `"type": "text"`, the array is replaced by the parts' `text` values
/// joined with `"\n"`, to match chat template expectations.
///
/// Left untouched:
/// - messages whose `content` is a string, is missing, or is an empty array;
/// - arrays containing any part whose type is not `"text"` (mixed or non-text
///   content is left for the chat template to handle);
/// - `messages` itself when it is not a JSON array.
///
/// The input is owned, so messages are rewritten in place instead of being
/// cloned one by one.
fn may_be_fix_msg_content(mut messages: serde_json::Value) -> Value {
    if let Some(msgs) = messages.as_array_mut() {
        for msg in msgs.iter_mut() {
            // `get_mut` yields `None` for non-object messages and for objects
            // without a `content` key; both are passed through unchanged.
            let Some(content) = msg.get_mut("content") else {
                continue;
            };

            // Compute the replacement first so the shared borrow of `content`
            // has ended before it is overwritten.
            let joined = content
                .as_array()
                .and_then(|parts| join_text_only_parts(parts));

            if let Some(joined) = joined {
                *content = serde_json::Value::String(joined);
            }
        }
    }

    Value::from_serialize(&messages)
}

/// Returns the newline-joined text of `parts` when the slice is non-empty and
/// every part has `"type": "text"`; returns `None` otherwise.
fn join_text_only_parts(parts: &[serde_json::Value]) -> Option<String> {
    let is_text_part = |part: &serde_json::Value| {
        part.get("type").and_then(serde_json::Value::as_str) == Some("text")
    };

    if parts.is_empty() || !parts.iter().all(is_text_part) {
        return None;
    }

    // A text part without a string `text` field contributes nothing to the
    // joined output.
    let texts: Vec<&str> = parts
        .iter()
        .filter_map(|part| part.get("text")?.as_str())
        .collect();

    Some(texts.join("\n"))
}

impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
/// Returns an owned copy of the model name stored on the inner request.
fn model(&self) -> String {
    let model_name = &self.inner.model;
    model_name.clone()
}

/// Returns the request's messages as a template `Value`.
///
/// The messages are first converted to JSON. If any message carries an array
/// as its `content`, the JSON is passed through `may_be_fix_msg_content`;
/// otherwise it is handed to the template engine as-is.
fn messages(&self) -> Value {
    // If the JSON conversion fails, fall back to converting the messages
    // directly for the template engine instead of panicking.
    let Ok(messages_json) = serde_json::to_value(&self.inner.messages) else {
        return Value::from_serialize(&self.inner.messages);
    };

    // Only messages with array-valued `content` can need rewriting.
    let has_array_content = messages_json
        .as_array()
        .into_iter()
        .flatten()
        .any(|msg| matches!(msg.get("content"), Some(serde_json::Value::Array(_))));

    if has_array_content {
        may_be_fix_msg_content(messages_json)
    } else {
        Value::from_serialize(&messages_json)
    }
}

fn tools(&self) -> Option<Value> {
Expand Down Expand Up @@ -335,4 +397,241 @@ mod tests {
);
assert_eq!(tools[0]["function"]["parameters"]["type"], "object");
}

/// A user message whose `content` is a list of text parts must come back from
/// `messages()` as a single string with the parts joined by a newline.
#[test]
fn test_may_be_fix_msg_content_user_multipart() {
    let payload = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "part 1"},
{"type": "text", "text": "part 2"}
]
}
]
}"#;

    let parsed: NvCreateChatCompletionRequest = serde_json::from_str(payload).unwrap();
    let rendered = serde_json::to_value(parsed.messages()).unwrap();

    // Both text parts collapse into one newline-separated string.
    assert_eq!(rendered[0]["content"], "part 1\npart 2");
}

/// Exercises a multi-turn conversation that mixes plain string content with
/// text-only part arrays: strings must pass through untouched while each
/// text-only array is joined with newlines.
#[test]
fn test_may_be_fix_msg_content_mixed_messages() {
    let payload = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": [
{"type": "text", "text": "Hello"},
{"type": "text", "text": "World"}
]
},
{
"role": "assistant",
"content": "Hi there!"
},
{
"role": "user",
"content": [
{"type": "text", "text": "Another"},
{"type": "text", "text": "multi-part"},
{"type": "text", "text": "message"}
]
}
]
}"#;

    let parsed: NvCreateChatCompletionRequest = serde_json::from_str(payload).unwrap();
    let rendered = serde_json::to_value(parsed.messages()).unwrap();

    // Expected `content` per message, in conversation order:
    // system (unchanged), user (joined), assistant (unchanged), user (joined).
    let expected_contents = [
        "You are a helpful assistant",
        "Hello\nWorld",
        "Hi there!",
        "Another\nmulti-part\nmessage",
    ];

    for (idx, expected) in expected_contents.iter().enumerate() {
        assert_eq!(rendered[idx]["content"], *expected);
    }
}

/// An empty `content` array must be passed through as an empty array.
#[test]
fn test_may_be_fix_msg_content_empty_array() {
    let payload = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": []
}
]
}"#;

    let parsed: NvCreateChatCompletionRequest = serde_json::from_str(payload).unwrap();
    let rendered = serde_json::to_value(parsed.messages()).unwrap();

    // `Some(0)` means: still an array, and still empty.
    let part_count = rendered[0]["content"]
        .as_array()
        .map(|parts| parts.len());
    assert_eq!(part_count, Some(0));
}

/// A message whose `content` is already a plain string must be returned
/// unmodified.
#[test]
fn test_may_be_fix_msg_content_single_text() {
    let payload = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": "Simple text message"
}
]
}"#;

    let parsed: NvCreateChatCompletionRequest = serde_json::from_str(payload).unwrap();
    let rendered = serde_json::to_value(parsed.messages()).unwrap();

    // The content is still a string and carries the original text.
    assert_eq!(
        rendered[0]["content"].as_str(),
        Some("Simple text message")
    );
}

/// A `content` array that mixes text parts with a non-text part must stay an
/// array, with every part kept in its original order.
#[test]
fn test_may_be_fix_msg_content_mixed_types() {
    let payload = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "Check this image:"},
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
{"type": "text", "text": "What do you see?"}
]
}
]
}"#;

    let parsed: NvCreateChatCompletionRequest = serde_json::from_str(payload).unwrap();
    let rendered = serde_json::to_value(parsed.messages()).unwrap();

    // The array structure survives so the template can handle the parts itself.
    let parts = rendered[0]["content"].as_array().unwrap();
    let part_types: Vec<&str> = parts
        .iter()
        .map(|part| part["type"].as_str().unwrap())
        .collect();

    // Comparing against a fixed-size array checks length and order at once.
    assert_eq!(part_types, ["text", "image_url", "text"]);
}

/// A `content` array made up solely of non-text parts must stay an array.
#[test]
fn test_may_be_fix_msg_content_non_text_only() {
    let payload = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": "https://example.com/image1.jpg"}},
{"type": "image_url", "image_url": {"url": "https://example.com/image2.jpg"}}
]
}
]
}"#;

    let parsed: NvCreateChatCompletionRequest = serde_json::from_str(payload).unwrap();
    let rendered = serde_json::to_value(parsed.messages()).unwrap();

    // Both image parts are still present and still typed as `image_url`.
    let parts = rendered[0]["content"].as_array().unwrap();
    assert_eq!(parts.len(), 2);
    assert!(parts.iter().all(|part| part["type"] == "image_url"));
}

/// Covers two mixed-content scenarios in which text parts are interleaved with
/// other part types; in both, `content` must remain an array of the same
/// length.
#[test]
fn test_may_be_fix_msg_content_multiple_content_types() {
    // Parses a request, renders its messages, and reports how many parts the
    // first message's `content` array holds (`None` if it is not an array).
    let content_part_count = |payload: &str| {
        let parsed: NvCreateChatCompletionRequest = serde_json::from_str(payload).unwrap();
        let rendered = serde_json::to_value(parsed.messages()).unwrap();
        rendered[0]["content"]
            .as_array()
            .map(|parts| parts.len())
    };

    // Scenario 1: text interleaved with audio and image parts.
    let text_audio_image = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "Listen to this:"},
{"type": "audio_url", "audio_url": {"url": "https://example.com/audio.mp3"}},
{"type": "text", "text": "And look at:"},
{"type": "image_url", "image_url": {"url": "https://example.com/img.jpg"}},
{"type": "text", "text": "What do you think?"}
]
}
]
}"#;
    assert_eq!(content_part_count(text_audio_image), Some(5));

    // Scenario 2: text interleaved with a video part.
    let text_video = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "Check this:"},
{"type": "video_url", "video_url": {"url": "https://example.com/vid.mp4"}},
{"type": "text", "text": "Interesting?"}
]
}
]
}"#;
    assert_eq!(content_part_count(text_video), Some(3));
}
}
Loading