diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index f937e776..e660cc74 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -123,7 +123,8 @@ mod tests { use crate::infer::chat_template::{raise_exception, strftime_now}; use crate::infer::ChatTemplate; use crate::{ - ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool, + ChatTemplateInputs, Message, MessageBody, MessageContent, TextMessage, + TokenizerConfigToken, Tool, }; use chrono::Local; use minijinja::Environment; @@ -1160,24 +1161,27 @@ TOOL CALL ID: 0 Message { name: None, role: "user".to_string(), - content: Some(MessageContent::SingleText( - "I'd like to show off how chat templating works!".to_string(), - )), - tool_calls: None, + body: MessageBody::Content { + content: MessageContent::SingleText( + "I'd like to show off how chat templating works!".to_string(), + ), + }, }, Message { name: None, role: "assistant".to_string(), - content: Some(MessageContent::SingleText( - "Great! How can I help you today?".to_string(), - )), - tool_calls: None, + body: MessageBody::Content { + content: MessageContent::SingleText( + "Great! How can I help you today?".to_string(), + ), + }, }, Message { name: None, role: "user".to_string(), - content: Some(MessageContent::SingleText("Just testing".to_string())), - tool_calls: None, + body: MessageBody::Content { + content: MessageContent::SingleText("Just testing".to_string()), + }, }, ]; let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); @@ -1201,19 +1205,21 @@ TOOL CALL ID: 0 Message { name: None, role: "system".to_string(), - content: Some(MessageContent::SingleText( - "Youre a helpful assistant! Answer the users question best you can." - .to_string(), - )), - tool_calls: None, + body: MessageBody::Content { + content: MessageContent::SingleText( + "Youre a helpful assistant! Answer the users question best you can." + .to_string(), + ), + }, }, Message { name: None, role: "user".to_string(), - content: Some(MessageContent::SingleText( - "What is the weather like in Brooklyn, New York?".to_string(), - )), - tool_calls: None, + body: MessageBody::Content { + content: MessageContent::SingleText( + "What is the weather like in Brooklyn, New York?".to_string(), + ), + }, }, ]; let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); diff --git a/router/src/lib.rs b/router/src/lib.rs index 94c7a48d..e8c875a8 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1596,12 +1596,11 @@ mod tests { assert_eq!( request.messages[0], Message { - role: "user".to_string(), - content: Some(MessageContent::SingleText( - "What is Deep Learning?".to_string() - )), name: None, - tool_calls: None + role: "user".to_string(), + body: MessageBody::Content { + content: MessageContent::SingleText("What is Deep Learning?".to_string()) + }, } ); } @@ -1651,14 +1650,16 @@ mod tests { assert_eq!( request.messages[0], - Message{ - role: "user".to_string(), - content: Some(MessageContent::MultipleChunks(vec![ - MessageChunk::Text { text: "Whats in this image?".to_string() }, - MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }}, - ])), + Message { name: None, - tool_calls: None + role: "user".to_string(), + + body: MessageBody::Content { + content: MessageContent::MultipleChunks(vec![ + MessageChunk::Text { text: "Whats in this image?".to_string() }, + MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }}, + ]), + }, } ); } @@ -1666,13 +1667,14 @@ mod tests { #[test] fn text_message_convert() { let message = Message{ + name: None, role: "user".to_string(), - content: Some(MessageContent::MultipleChunks(vec![ - MessageChunk::Text { text: "Whats in this image?".to_string() }, - MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } } - ])), - name: None, - tool_calls: None + body: MessageBody::Content { + content: MessageContent::MultipleChunks(vec![ + MessageChunk::Text { text: "Whats in this image?".to_string() }, + MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } } + ]), + } }; let textmsg: TextMessage = message.into(); assert_eq!(textmsg.content, "Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)"); diff --git a/router/src/vertex.rs b/router/src/vertex.rs index 5a4a3876..38695532 100644 --- a/router/src/vertex.rs +++ b/router/src/vertex.rs @@ -147,7 +147,7 @@ pub(crate) async fn vertex_compatibility( #[cfg(test)] mod tests { use super::*; - use crate::{Message, MessageContent}; + use crate::{Message, MessageBody, MessageContent}; #[test] fn vertex_deserialization() { @@ -169,12 +169,13 @@ mod tests { VertexRequest { instances: vec![VertexInstance::Chat(ChatRequest { messages: vec![Message { - role: "user".to_string(), - content: Some(MessageContent::SingleText( - "What's Deep Learning?".to_string() - )), name: None, - ..Default::default() + role: "user".to_string(), + body: MessageBody::Content { + content: MessageContent::SingleText( + "What's Deep Learning?".to_string() + ) + }, },], max_tokens: Some(128), top_p: Some(0.95),