mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
feat: add test and serialize tool messages
This commit is contained in:
parent
f5e1a16582
commit
ac50b14afb
@ -0,0 +1,26 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "I'm an AI and do not have access to real-time data. However, based on location information (Paris) I can provide general information. \n\nThe temperature in Paris varies widely throughout the year. In the summer (June to August), the average high temperature is around 23°C (73°F), while in the winter (December to February), the average low temperature is around -1°C (30°F). \n\nTo get the current weather in Paris, I recommend checking a weather website or",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1739903191,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.1.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 100,
|
||||
"prompt_tokens": 103,
|
||||
"total_tokens": 203
|
||||
}
|
||||
}
|
@ -468,3 +468,42 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
|
||||
== '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days":3}}<|eot_id|>'
|
||||
)
|
||||
assert last_response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_tool_reply_response(
|
||||
flash_llama_grammar_tools, response_snapshot
|
||||
):
|
||||
responses = await flash_llama_grammar_tools.chat(
|
||||
max_tokens=100,
|
||||
seed=43,
|
||||
messages=[
|
||||
{"role": "user", "content": "What's the weather like in Paris today?"},
|
||||
{
|
||||
"content": "",
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "0",
|
||||
"function": {
|
||||
"arguments": '{"longitude": 2.2945, "latitude": 48.8567}',
|
||||
"name": "get_weather",
|
||||
"description": None,
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "0", "content": "6.7"},
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert responses.choices[0].message.tool_calls is None
|
||||
assert (
|
||||
responses.choices[0].message.content
|
||||
== "I'm an AI and do not have access to real-time data. However, based on location information (Paris) I can provide general information. \n\nThe temperature in Paris varies widely throughout the year. In the summer (June to August), the average high temperature is around 23°C (73°F), while in the winter (December to February), the average low temperature is around -1°C (30°F). \n\nTo get the current weather in Paris, I recommend checking a weather website or"
|
||||
)
|
||||
|
||||
assert responses == response_snapshot
|
||||
|
@ -73,10 +73,8 @@ impl ChatTemplate {
|
||||
// if the `tools` variable is used in the template, we just append the tool_prompt
|
||||
format!("\n---\n{}", tool_prompt)
|
||||
};
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
if let Some(content) = last_message.content.as_mut() {
|
||||
content.push(MessageChunk::Text { text });
|
||||
}
|
||||
if let Some(content) = messages.last_mut().and_then(|msg| msg.content.as_mut()) {
|
||||
content.push(MessageChunk::Text { text })
|
||||
}
|
||||
Some(tools)
|
||||
}
|
||||
|
@ -1222,22 +1222,24 @@ pub struct TextMessage {
|
||||
|
||||
impl From<Message> for TextMessage {
|
||||
fn from(value: Message) -> Self {
|
||||
let content = value
|
||||
.tool_calls
|
||||
.map(|calls| serde_json::to_string(&calls).unwrap_or_default())
|
||||
.map(MessageContent::SingleText)
|
||||
.or(value.content)
|
||||
.unwrap_or_else(|| MessageContent::SingleText(String::new()));
|
||||
TextMessage {
|
||||
role: value.role,
|
||||
content: match value.content {
|
||||
// If content is Some(MessageContent), handle it accordingly
|
||||
Some(MessageContent::SingleText(text)) => text,
|
||||
Some(MessageContent::MultipleChunks(chunks)) => {
|
||||
chunks.into_iter()
|
||||
.map(|chunk| match chunk {
|
||||
MessageChunk::Text { text } => text,
|
||||
MessageChunk::ImageUrl { image_url } => format!("", image_url.url),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("")
|
||||
}
|
||||
// If content is None, use an empty string or a default message
|
||||
None => String::new(), // or you could use "No content" or another placeholder
|
||||
content: match content {
|
||||
MessageContent::SingleText(text) => text,
|
||||
MessageContent::MultipleChunks(chunks) => chunks
|
||||
.into_iter()
|
||||
.map(|chunk| match chunk {
|
||||
MessageChunk::Text { text } => text,
|
||||
MessageChunk::ImageUrl { image_url } => format!("", image_url.url),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(""),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user