fix: adjust message types in tests

This commit is contained in:
drbh 2025-02-20 19:38:45 +00:00
parent 4fa8512d99
commit 3770344529
3 changed files with 53 additions and 44 deletions

View File

@ -123,7 +123,8 @@ mod tests {
use crate::infer::chat_template::{raise_exception, strftime_now}; use crate::infer::chat_template::{raise_exception, strftime_now};
use crate::infer::ChatTemplate; use crate::infer::ChatTemplate;
use crate::{ use crate::{
ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool, ChatTemplateInputs, Message, MessageBody, MessageContent, TextMessage,
TokenizerConfigToken, Tool,
}; };
use chrono::Local; use chrono::Local;
use minijinja::Environment; use minijinja::Environment;
@ -1160,24 +1161,27 @@ TOOL CALL ID: 0
Message { Message {
name: None, name: None,
role: "user".to_string(), role: "user".to_string(),
content: Some(MessageContent::SingleText( body: MessageBody::Content {
content: MessageContent::SingleText(
"I'd like to show off how chat templating works!".to_string(), "I'd like to show off how chat templating works!".to_string(),
)), ),
tool_calls: None, },
}, },
Message { Message {
name: None, name: None,
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some(MessageContent::SingleText( body: MessageBody::Content {
content: MessageContent::SingleText(
"Great! How can I help you today?".to_string(), "Great! How can I help you today?".to_string(),
)), ),
tool_calls: None, },
}, },
Message { Message {
name: None, name: None,
role: "user".to_string(), role: "user".to_string(),
content: Some(MessageContent::SingleText("Just testing".to_string())), body: MessageBody::Content {
tool_calls: None, content: MessageContent::SingleText("Just testing".to_string()),
},
}, },
]; ];
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
@ -1201,19 +1205,21 @@ TOOL CALL ID: 0
Message { Message {
name: None, name: None,
role: "system".to_string(), role: "system".to_string(),
content: Some(MessageContent::SingleText( body: MessageBody::Content {
content: MessageContent::SingleText(
"Youre a helpful assistant! Answer the users question best you can." "Youre a helpful assistant! Answer the users question best you can."
.to_string(), .to_string(),
)), ),
tool_calls: None, },
}, },
Message { Message {
name: None, name: None,
role: "user".to_string(), role: "user".to_string(),
content: Some(MessageContent::SingleText( body: MessageBody::Content {
content: MessageContent::SingleText(
"What is the weather like in Brooklyn, New York?".to_string(), "What is the weather like in Brooklyn, New York?".to_string(),
)), ),
tool_calls: None, },
}, },
]; ];
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();

View File

@ -1596,12 +1596,11 @@ mod tests {
assert_eq!( assert_eq!(
request.messages[0], request.messages[0],
Message { Message {
role: "user".to_string(),
content: Some(MessageContent::SingleText(
"What is Deep Learning?".to_string()
)),
name: None, name: None,
tool_calls: None role: "user".to_string(),
body: MessageBody::Content {
content: MessageContent::SingleText("What is Deep Learning?".to_string())
},
} }
); );
} }
@ -1651,14 +1650,16 @@ mod tests {
assert_eq!( assert_eq!(
request.messages[0], request.messages[0],
Message{ Message {
name: None,
role: "user".to_string(), role: "user".to_string(),
content: Some(MessageContent::MultipleChunks(vec![
body: MessageBody::Content {
content: MessageContent::MultipleChunks(vec![
MessageChunk::Text { text: "Whats in this image?".to_string() }, MessageChunk::Text { text: "Whats in this image?".to_string() },
MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }}, MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }},
])), ]),
name: None, },
tool_calls: None
} }
); );
} }
@ -1666,13 +1667,14 @@ mod tests {
#[test] #[test]
fn text_message_convert() { fn text_message_convert() {
let message = Message{ let message = Message{
name: None,
role: "user".to_string(), role: "user".to_string(),
content: Some(MessageContent::MultipleChunks(vec![ body: MessageBody::Content {
content: MessageContent::MultipleChunks(vec![
MessageChunk::Text { text: "Whats in this image?".to_string() }, MessageChunk::Text { text: "Whats in this image?".to_string() },
MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } } MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }
])), ]),
name: None, }
tool_calls: None
}; };
let textmsg: TextMessage = message.into(); let textmsg: TextMessage = message.into();
assert_eq!(textmsg.content, "Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)"); assert_eq!(textmsg.content, "Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)");

View File

@ -147,7 +147,7 @@ pub(crate) async fn vertex_compatibility(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::{Message, MessageContent}; use crate::{Message, MessageBody, MessageContent};
#[test] #[test]
fn vertex_deserialization() { fn vertex_deserialization() {
@ -169,12 +169,13 @@ mod tests {
VertexRequest { VertexRequest {
instances: vec![VertexInstance::Chat(ChatRequest { instances: vec![VertexInstance::Chat(ChatRequest {
messages: vec![Message { messages: vec![Message {
role: "user".to_string(),
content: Some(MessageContent::SingleText(
"What's Deep Learning?".to_string()
)),
name: None, name: None,
..Default::default() role: "user".to_string(),
body: MessageBody::Content {
content: MessageContent::SingleText(
"What's Deep Learning?".to_string()
)
},
},], },],
max_tokens: Some(128), max_tokens: Some(128),
top_p: Some(0.95), top_p: Some(0.95),