diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_tool_reply_response.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_tool_reply_response.json new file mode 100644 index 00000000..3742f342 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_tool_reply_response.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "I'm an AI and do not have access to real-time data. However, based on location information (Paris) I can provide general information. \n\nThe temperature in Paris varies widely throughout the year. In the summer (June to August), the average high temperature is around 23°C (73°F), while in the winter (December to February), the average low temperature is around -1°C (30°F). \n\nTo get the current weather in Paris, I recommend checking a weather website or", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1739903191, + "id": "", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "object": "chat.completion", + "system_fingerprint": "3.1.1-dev0-native", + "usage": { + "completion_tokens": 100, + "prompt_tokens": 103, + "total_tokens": 203 + } +} diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index 70c3aff0..c73c1a92 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -468,3 +468,42 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object( == '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days":3}}<|eot_id|>' ) assert last_response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_tool_reply_response( + flash_llama_grammar_tools, response_snapshot +): + responses = await 
flash_llama_grammar_tools.chat( + max_tokens=100, + seed=43, + messages=[ + {"role": "user", "content": "What's the weather like in Paris today?"}, + { + "content": "", + "role": "assistant", + "tool_calls": [ + { + "id": "0", + "function": { + "arguments": '{"longitude": 2.2945, "latitude": 48.8567}', + "name": "get_weather", + "description": None, + }, + "type": "function", + } + ], + }, + {"role": "tool", "tool_call_id": "0", "content": "6.7"}, + ], + stream=False, + ) + + assert responses.choices[0].message.tool_calls is None + assert ( + responses.choices[0].message.content + == "I'm an AI and do not have access to real-time data. However, based on location information (Paris) I can provide general information. \n\nThe temperature in Paris varies widely throughout the year. In the summer (June to August), the average high temperature is around 23°C (73°F), while in the winter (December to February), the average low temperature is around -1°C (30°F). \n\nTo get the current weather in Paris, I recommend checking a weather website or" + ) + + assert responses == response_snapshot diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index b5bd8c1e..1aa0860f 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -73,10 +73,8 @@ impl ChatTemplate { // if the `tools` variable is used in the template, we just append the tool_prompt format!("\n---\n{}", tool_prompt) }; - if let Some(last_message) = messages.last_mut() { - if let Some(content) = last_message.content.as_mut() { - content.push(MessageChunk::Text { text }); - } + if let Some(content) = messages.last_mut().and_then(|msg| msg.content.as_mut()) { + content.push(MessageChunk::Text { text }) } Some(tools) } diff --git a/router/src/lib.rs b/router/src/lib.rs index c135ac2d..dae11a23 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1222,22 +1222,24 @@ pub struct TextMessage { impl From<Message> for TextMessage { fn from(value: Message) -> Self { + 
let content = value + .tool_calls + .map(|calls| serde_json::to_string(&calls).unwrap_or_default()) + .map(MessageContent::SingleText) + .or(value.content) + .unwrap_or_else(|| MessageContent::SingleText(String::new())); TextMessage { role: value.role, - content: match value.content { - // If content is Some(MessageContent), handle it accordingly - Some(MessageContent::SingleText(text)) => text, - Some(MessageContent::MultipleChunks(chunks)) => { - chunks.into_iter() - .map(|chunk| match chunk { - MessageChunk::Text { text } => text, - MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url), - }) - .collect::<Vec<String>>() - .join("") - } - // If content is None, use an empty string or a default message - None => String::new(), // or you could use "No content" or another placeholder + content: match content { + MessageContent::SingleText(text) => text, + MessageContent::MultipleChunks(chunks) => chunks + .into_iter() + .map(|chunk| match chunk { + MessageChunk::Text { text } => text, + MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url), + }) + .collect::<Vec<String>>() + .join(""), }, } }