diff --git a/lib/llm/src/preprocessor/prompt/template/oai.rs b/lib/llm/src/preprocessor/prompt/template/oai.rs
index abe8759932..cc3f571e0d 100644
--- a/lib/llm/src/preprocessor/prompt/template/oai.rs
+++ b/lib/llm/src/preprocessor/prompt/template/oai.rs
@@ -72,13 +72,75 @@ fn may_be_fix_tool_schema(tools: serde_json::Value) -> Option<Value> {
     Some(Value::from_serialize(&updated_tools))
 }
 
+fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
+    // If messages[content] is provided as a list containing ONLY text parts,
+    // concatenate them into a string to match chat template expectations.
+    // Mixed content types are left for chat templates to handle.
+
+    let Some(arr) = messages.as_array() else {
+        return Value::from_serialize(&messages);
+    };
+
+    let updated_messages: Vec<_> = arr
+        .iter()
+        .map(|msg| {
+            match msg.get("content") {
+                Some(serde_json::Value::Array(content_array)) => {
+                    let is_text_only_array = !content_array.is_empty()
+                        && content_array.iter().all(|part| {
+                            part.get("type")
+                                .and_then(|type_field| type_field.as_str())
+                                .map(|type_str| type_str == "text")
+                                .unwrap_or(false)
+                        });
+
+                    if is_text_only_array {
+                        let mut modified_msg = msg.clone();
+                        if let Some(msg_object) = modified_msg.as_object_mut() {
+                            let text_parts: Vec<&str> = content_array
+                                .iter()
+                                .filter_map(|part| part.get("text")?.as_str())
+                                .collect();
+                            let concatenated_text = text_parts.join("\n");
+
+                            msg_object.insert(
+                                "content".to_string(),
+                                serde_json::Value::String(concatenated_text),
+                            );
+                        }
+                        modified_msg // Concatenated string content
+                    } else {
+                        msg.clone() // Mixed content or non-text only
+                    }
+                }
+                _ => msg.clone(), // String content or missing content - return unchanged
+            }
+        })
+        .collect();
+
+    Value::from_serialize(&updated_messages)
+}
+
 impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
     fn model(&self) -> String {
         self.inner.model.clone()
     }
 
     fn messages(&self) -> Value {
-        Value::from_serialize(&self.inner.messages)
+        let messages_json = serde_json::to_value(&self.inner.messages).unwrap();
+
+        let needs_fixing = if let Some(arr) = messages_json.as_array() {
+            arr.iter()
+                .any(|msg| msg.get("content").and_then(|c| c.as_array()).is_some())
+        } else {
+            false
+        };
+
+        if needs_fixing {
+            may_be_fix_msg_content(messages_json)
+        } else {
+            Value::from_serialize(&messages_json)
+        }
     }
 
     fn tools(&self) -> Option<Value> {
@@ -335,4 +397,241 @@ mod tests {
         );
         assert_eq!(tools[0]["function"]["parameters"]["type"], "object");
     }
+
+    /// Tests that content arrays (containing only text parts) are correctly concatenated.
+    #[test]
+    fn test_may_be_fix_msg_content_user_multipart() {
+        let json_str = r#"{
+            "model": "gpt-4o",
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "part 1"},
+                        {"type": "text", "text": "part 2"}
+                    ]
+                }
+            ]
+        }"#;
+
+        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
+        let messages = serde_json::to_value(request.messages()).unwrap();
+
+        // Verify: text-only array is concatenated into a single string
+        assert_eq!(
+            messages[0]["content"],
+            serde_json::Value::String("part 1\npart 2".to_string())
+        );
+    }
+
+    /// Tests that the function correctly handles a conversation
+    /// with multiple roles and mixed message types.
+    #[test]
+    fn test_may_be_fix_msg_content_mixed_messages() {
+        let json_str = r#"{
+            "model": "gpt-4o",
+            "messages": [
+                {
+                    "role": "system",
+                    "content": "You are a helpful assistant"
+                },
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Hello"},
+                        {"type": "text", "text": "World"}
+                    ]
+                },
+                {
+                    "role": "assistant",
+                    "content": "Hi there!"
+                },
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Another"},
+                        {"type": "text", "text": "multi-part"},
+                        {"type": "text", "text": "message"}
+                    ]
+                }
+            ]
+        }"#;
+
+        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
+        let messages = serde_json::to_value(request.messages()).unwrap();
+
+        // Verify: System message with string content remains unchanged
+        assert_eq!(
+            messages[0]["content"],
+            serde_json::Value::String("You are a helpful assistant".to_string())
+        );
+
+        // Verify: User message with text-only array is concatenated
+        assert_eq!(
+            messages[1]["content"],
+            serde_json::Value::String("Hello\nWorld".to_string())
+        );
+
+        // Verify: Assistant message with string content remains unchanged
+        assert_eq!(
+            messages[2]["content"],
+            serde_json::Value::String("Hi there!".to_string())
+        );
+
+        // Verify: Second user message with text-only array is concatenated
+        assert_eq!(
+            messages[3]["content"],
+            serde_json::Value::String("Another\nmulti-part\nmessage".to_string())
+        );
+    }
+
+    /// Tests that empty content arrays remain unchanged.
+    #[test]
+    fn test_may_be_fix_msg_content_empty_array() {
+        let json_str = r#"{
+            "model": "gpt-4o",
+            "messages": [
+                {
+                    "role": "user",
+                    "content": []
+                }
+            ]
+        }"#;
+
+        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
+        let messages = serde_json::to_value(request.messages()).unwrap();
+
+        // Verify: Empty arrays are preserved as-is
+        assert!(messages[0]["content"].is_array());
+        assert_eq!(messages[0]["content"].as_array().unwrap().len(), 0);
+    }
+
+    /// Tests that messages with simple string content remain unchanged.
+    #[test]
+    fn test_may_be_fix_msg_content_single_text() {
+        let json_str = r#"{
+            "model": "gpt-4o",
+            "messages": [
+                {
+                    "role": "user",
+                    "content": "Simple text message"
+                }
+            ]
+        }"#;
+
+        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
+        let messages = serde_json::to_value(request.messages()).unwrap();
+
+        // Verify: String content is not modified
+        assert_eq!(
+            messages[0]["content"],
+            serde_json::Value::String("Simple text message".to_string())
+        );
+    }
+
+    /// Tests that content arrays with mixed types (text + non-text) remain as arrays.
+    #[test]
+    fn test_may_be_fix_msg_content_mixed_types() {
+        let json_str = r#"{
+            "model": "gpt-4o",
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Check this image:"},
+                        {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
+                        {"type": "text", "text": "What do you see?"}
+                    ]
+                }
+            ]
+        }"#;
+
+        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
+        let messages = serde_json::to_value(request.messages()).unwrap();
+
+        // Verify: Mixed content types are preserved as array for template handling
+        assert!(messages[0]["content"].is_array());
+        let content_array = messages[0]["content"].as_array().unwrap();
+        assert_eq!(content_array.len(), 3);
+        assert_eq!(content_array[0]["type"], "text");
+        assert_eq!(content_array[1]["type"], "image_url");
+        assert_eq!(content_array[2]["type"], "text");
+    }
+
+    /// Tests that content arrays containing only non-text types remain as arrays.
+    #[test]
+    fn test_may_be_fix_msg_content_non_text_only() {
+        let json_str = r#"{
+            "model": "gpt-4o",
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image_url", "image_url": {"url": "https://example.com/image1.jpg"}},
+                        {"type": "image_url", "image_url": {"url": "https://example.com/image2.jpg"}}
+                    ]
+                }
+            ]
+        }"#;
+
+        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
+        let messages = serde_json::to_value(request.messages()).unwrap();
+
+        // Verify: Non-text content arrays are preserved for template handling
+        assert!(messages[0]["content"].is_array());
+        let content_array = messages[0]["content"].as_array().unwrap();
+        assert_eq!(content_array.len(), 2);
+        assert_eq!(content_array[0]["type"], "image_url");
+        assert_eq!(content_array[1]["type"], "image_url");
+    }
+
+    /// Tests mixed content type scenarios.
+    #[test]
+    fn test_may_be_fix_msg_content_multiple_content_types() {
+        // Scenario 1: Multiple different content types (text + image + audio)
+        let json_str = r#"{
+            "model": "gpt-4o",
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Listen to this:"},
+                        {"type": "audio_url", "audio_url": {"url": "https://example.com/audio.mp3"}},
+                        {"type": "text", "text": "And look at:"},
+                        {"type": "image_url", "image_url": {"url": "https://example.com/img.jpg"}},
+                        {"type": "text", "text": "What do you think?"}
+                    ]
+                }
+            ]
+        }"#;
+
+        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
+        let messages = serde_json::to_value(request.messages()).unwrap();
+
+        // Mixed types should preserve array structure
+        assert!(messages[0]["content"].is_array());
+        assert_eq!(messages[0]["content"].as_array().unwrap().len(), 5);
+
+        // Scenario 2: Unknown/future content types mixed with text
+        let json_str = r#"{
+            "model": "gpt-4o",
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Check this:"},
+                        {"type": "video_url", "video_url": {"url": "https://example.com/vid.mp4"}},
+                        {"type": "text", "text": "Interesting?"}
+                    ]
+                }
+            ]
+        }"#;
+
+        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
+        let messages = serde_json::to_value(request.messages()).unwrap();
+
+        // Unknown types mixed with text should preserve array
+        assert!(messages[0]["content"].is_array());
+        assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);
+    }
 }