Skip to content
237 changes: 236 additions & 1 deletion lib/llm/src/preprocessor/prompt/template/oai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,77 @@ fn may_be_fix_tool_schema(tools: serde_json::Value) -> Option<Value> {
Some(Value::from_serialize(&updated_tools))
}

/// Collapses a message's `content` into a single string when it is a
/// non-empty array made up ONLY of `{"type": "text", ...}` parts.
///
/// The `text` fields of the parts are joined with `"\n"`. Parts tagged as text
/// but lacking a string `text` field contribute nothing to the result.
/// Messages whose content is missing, not an array, empty, or of mixed part
/// types are left untouched.
fn collapse_text_only_content(msg: &mut serde_json::Value) {
    let Some(parts) = msg.get("content").and_then(|c| c.as_array()) else {
        return;
    };
    if parts.is_empty() {
        return;
    }

    let all_text = parts
        .iter()
        .all(|part| part.get("type").and_then(|t| t.as_str()) == Some("text"));
    if !all_text {
        return;
    }

    // Build the owned string first so the shared borrow of `msg` ends before
    // the message is mutated below.
    let joined = parts
        .iter()
        .filter_map(|part| part.get("text").and_then(|t| t.as_str()))
        .collect::<Vec<_>>()
        .join("\n");

    if let Some(obj) = msg.as_object_mut() {
        obj.insert("content".to_string(), serde_json::Value::String(joined));
    }
}

/// Normalizes chat messages before they are handed to the chat template.
///
/// If `messages[content]` is provided as a list containing ONLY text parts,
/// the parts are concatenated (newline-separated) into a string to match chat
/// template expectations. Mixed content types are left for chat templates to
/// handle. A `messages` value that is not a JSON array is converted as-is.
///
/// The input is consumed and rewritten in place, so messages that need no
/// change are never cloned.
fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
    match messages {
        serde_json::Value::Array(mut msgs) => {
            for msg in &mut msgs {
                collapse_text_only_content(msg);
            }
            Value::from_serialize(&msgs)
        }
        other => Value::from_serialize(&other),
    }
}

impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
/// Returns an owned copy of the model name carried by the wrapped request.
fn model(&self) -> String {
    self.inner.model.to_owned()
}

/// Returns the request's chat messages as a template value.
///
/// The messages are first serialized to JSON. When at least one message
/// carries its `content` as an array, the JSON is routed through
/// `may_be_fix_msg_content`; otherwise it is converted unchanged.
fn messages(&self) -> Value {
    let messages_json = match serde_json::to_value(&self.inner.messages) {
        Ok(json) => json,
        // Do not panic on the request path: hand the typed messages straight
        // to the template value conversion instead.
        // NOTE(review): assumes `Value::from_serialize` reports serializer
        // failures without panicking — confirm against its documentation.
        Err(_) => return Value::from_serialize(&self.inner.messages),
    };

    // Only messages with array-shaped content can require collapsing, so skip
    // the rewrite pass entirely when none are present.
    let needs_collapse = messages_json.as_array().is_some_and(|arr| {
        arr.iter()
            .any(|msg| msg.get("content").is_some_and(|c| c.is_array()))
    });

    if needs_collapse {
        may_be_fix_msg_content(messages_json)
    } else {
        Value::from_serialize(&messages_json)
    }
}

fn tools(&self) -> Option<Value> {
Expand Down Expand Up @@ -339,4 +403,175 @@ mod tests {
);
assert_eq!(tools[0]["function"]["parameters"]["type"], "object");
}

/// A user message made of two text parts is rendered as one
/// newline-joined string.
#[test]
fn test_may_be_fix_msg_content_user_multipart() {
    let payload = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "part 1"},
{"type": "text", "text": "part 2"}
]
}
]
}"#;

    let req = serde_json::from_str::<NvCreateChatCompletionRequest>(payload).unwrap();
    let rendered = serde_json::to_value(req.messages()).unwrap();

    assert_eq!(rendered[0]["content"].as_str(), Some("part 1\npart 2"));
}

/// In a conversation mixing plain-string and multi-part text messages,
/// string contents pass through unchanged and each multi-part content is
/// joined with newlines.
#[test]
fn test_may_be_fix_msg_content_mixed_messages() {
    let payload = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": [
{"type": "text", "text": "Hello"},
{"type": "text", "text": "World"}
]
},
{
"role": "assistant",
"content": "Hi there!"
},
{
"role": "user",
"content": [
{"type": "text", "text": "Another"},
{"type": "text", "text": "multi-part"},
{"type": "text", "text": "message"}
]
}
]
}"#;

    let req = serde_json::from_str::<NvCreateChatCompletionRequest>(payload).unwrap();
    let rendered = serde_json::to_value(req.messages()).unwrap();

    // Expected content of each message, in conversation order.
    let expected = [
        "You are a helpful assistant",
        "Hello\nWorld",
        "Hi there!",
        "Another\nmulti-part\nmessage",
    ];
    for (idx, want) in expected.iter().enumerate() {
        assert_eq!(rendered[idx]["content"].as_str(), Some(*want));
    }
}

/// An empty content array is not collapsed: it stays an array of length zero.
#[test]
fn test_may_be_fix_msg_content_empty_array() {
    let payload = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": []
}
]
}"#;

    let req = serde_json::from_str::<NvCreateChatCompletionRequest>(payload).unwrap();
    let rendered = serde_json::to_value(req.messages()).unwrap();

    let part_count = rendered[0]["content"].as_array().map(|parts| parts.len());
    assert_eq!(part_count, Some(0));
}

/// A message whose content is already a plain string is returned unchanged.
#[test]
fn test_may_be_fix_msg_content_single_text() {
    let payload = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": "Simple text message"
}
]
}"#;

    let req = serde_json::from_str::<NvCreateChatCompletionRequest>(payload).unwrap();
    let rendered = serde_json::to_value(req.messages()).unwrap();

    assert_eq!(rendered[0]["content"].as_str(), Some("Simple text message"));
}

/// Content mixing text and image parts is left as an array, with all three
/// parts kept in their original order.
#[test]
fn test_may_be_fix_msg_content_mixed_types() {
    let payload = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "Check this image:"},
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
{"type": "text", "text": "What do you see?"}
]
}
]
}"#;

    let req = serde_json::from_str::<NvCreateChatCompletionRequest>(payload).unwrap();
    let rendered = serde_json::to_value(req.messages()).unwrap();

    // Comparing the full list of part types checks both length and order.
    let kinds: Vec<Option<&str>> = rendered[0]["content"]
        .as_array()
        .unwrap()
        .iter()
        .map(|part| part["type"].as_str())
        .collect();
    assert_eq!(
        kinds,
        vec![Some("text"), Some("image_url"), Some("text")]
    );
}

/// Content consisting solely of image parts is left as an array of the same
/// two parts.
#[test]
fn test_may_be_fix_msg_content_non_text_only() {
    let payload = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": "https://example.com/image1.jpg"}},
{"type": "image_url", "image_url": {"url": "https://example.com/image2.jpg"}}
]
}
]
}"#;

    let req = serde_json::from_str::<NvCreateChatCompletionRequest>(payload).unwrap();
    let rendered = serde_json::to_value(req.messages()).unwrap();

    // Comparing the full list of part types checks both length and order.
    let kinds: Vec<Option<&str>> = rendered[0]["content"]
        .as_array()
        .unwrap()
        .iter()
        .map(|part| part["type"].as_str())
        .collect();
    assert_eq!(kinds, vec![Some("image_url"), Some("image_url")]);
}
}
Loading