feat: add test and serialize tool messages

This commit is contained in:
drbh 2025-02-19 00:47:53 +00:00
parent f5e1a16582
commit ac50b14afb
4 changed files with 83 additions and 18 deletions

View File

@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "I'm an AI and do not have access to real-time data. However, based on location information (Paris) I can provide general information. \n\nThe temperature in Paris varies widely throughout the year. In the summer (June to August), the average high temperature is around 23°C (73°F), while in the winter (December to February), the average low temperature is around -1°C (30°F). \n\nTo get the current weather in Paris, I recommend checking a weather website or",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1739903191,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.1.1-dev0-native",
"usage": {
"completion_tokens": 100,
"prompt_tokens": 103,
"total_tokens": 203
}
}

View File

@ -468,3 +468,42 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
== '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days":3}}<|eot_id|>'
)
assert last_response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_tool_reply_response(
    flash_llama_grammar_tools, response_snapshot
):
    """Non-streaming chat where the history already contains an assistant
    tool call and the matching tool reply: the model should answer with
    plain text and emit no further tool calls."""
    conversation = [
        {"role": "user", "content": "What's the weather like in Paris today?"},
        {
            "content": "",
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "0",
                    "function": {
                        "arguments": '{"longitude": 2.2945, "latitude": 48.8567}',
                        "name": "get_weather",
                        "description": None,
                    },
                    "type": "function",
                }
            ],
        },
        {"role": "tool", "tool_call_id": "0", "content": "6.7"},
    ]
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=43,
        messages=conversation,
        stream=False,
    )
    message = responses.choices[0].message
    assert message.tool_calls is None
    expected_content = "I'm an AI and do not have access to real-time data. However, based on location information (Paris) I can provide general information. \n\nThe temperature in Paris varies widely throughout the year. In the summer (June to August), the average high temperature is around 23°C (73°F), while in the winter (December to February), the average low temperature is around -1°C (30°F). \n\nTo get the current weather in Paris, I recommend checking a weather website or"
    assert message.content == expected_content
    assert responses == response_snapshot

View File

@ -73,10 +73,8 @@ impl ChatTemplate {
// if the `tools` variable is used in the template, we just append the tool_prompt
format!("\n---\n{}", tool_prompt)
};
if let Some(last_message) = messages.last_mut() {
if let Some(content) = last_message.content.as_mut() {
content.push(MessageChunk::Text { text });
}
if let Some(content) = messages.last_mut().and_then(|msg| msg.content.as_mut()) {
content.push(MessageChunk::Text { text })
}
Some(tools)
}

View File

@ -1222,22 +1222,24 @@ pub struct TextMessage {
impl From<Message> for TextMessage {
fn from(value: Message) -> Self {
let content = value
.tool_calls
.map(|calls| serde_json::to_string(&calls).unwrap_or_default())
.map(MessageContent::SingleText)
.or(value.content)
.unwrap_or_else(|| MessageContent::SingleText(String::new()));
TextMessage {
role: value.role,
content: match value.content {
// If content is Some(MessageContent), handle it accordingly
Some(MessageContent::SingleText(text)) => text,
Some(MessageContent::MultipleChunks(chunks)) => {
chunks.into_iter()
.map(|chunk| match chunk {
MessageChunk::Text { text } => text,
MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url),
})
.collect::<Vec<_>>()
.join("")
}
// If content is None, use an empty string or a default message
None => String::new(), // or you could use "No content" or another placeholder
content: match content {
MessageContent::SingleText(text) => text,
MessageContent::MultipleChunks(chunks) => chunks
.into_iter()
.map(|chunk| match chunk {
MessageChunk::Text { text } => text,
MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url),
})
.collect::<Vec<_>>()
.join(""),
},
}
}