mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00

feat: add test and serialize tool messages

This commit is contained in:
parent f5e1a16582
commit ac50b14afb
@@ -0,0 +1,26 @@
+{
+  "choices": [
+    {
+      "finish_reason": "length",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "I'm an AI and do not have access to real-time data. However, based on location information (Paris) I can provide general information. \n\nThe temperature in Paris varies widely throughout the year. In the summer (June to August), the average high temperature is around 23°C (73°F), while in the winter (December to February), the average low temperature is around -1°C (30°F). \n\nTo get the current weather in Paris, I recommend checking a weather website or",
+        "name": null,
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "usage": null
+    }
+  ],
+  "created": 1739903191,
+  "id": "",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "3.1.1-dev0-native",
+  "usage": {
+    "completion_tokens": 100,
+    "prompt_tokens": 103,
+    "total_tokens": 203
+  }
+}
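The snapshot pins the expected non-streaming response after a full tool round trip: `finish_reason` is `length` (the test caps generation at 100 tokens) and `message.tool_calls` is `null`, since the assistant answers in plain text once the tool result is in the conversation. As a reading aid, here is a minimal sketch in Rust of serde structs that mirror the snapshot's keys; these are assumed stand-ins for illustration, not the router's actual response types:

// Sketch only: assumed struct names, mirroring the snapshot's JSON keys.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct ChatCompletion {
    choices: Vec<Choice>,
    created: u64,
    model: String,
    system_fingerprint: String,
    usage: Usage,
}

#[derive(Debug, Deserialize)]
struct Choice {
    finish_reason: String,
    index: u32,
    message: ChatMessage,
}

#[derive(Debug, Deserialize)]
struct ChatMessage {
    role: String,
    content: Option<String>,
    // The new test asserts this stays null after the tool round trip.
    tool_calls: Option<serde_json::Value>,
}

#[derive(Debug, Deserialize)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

fn main() {
    // Abbreviated copy of the snapshot above ("..." stands in for the reply).
    let snapshot = r#"{
        "choices": [{"finish_reason": "length", "index": 0, "logprobs": null,
            "message": {"content": "...", "name": null, "role": "assistant",
                        "tool_calls": null},
            "usage": null}],
        "created": 1739903191, "id": "",
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "object": "chat.completion", "system_fingerprint": "3.1.1-dev0-native",
        "usage": {"completion_tokens": 100, "prompt_tokens": 103, "total_tokens": 203}
    }"#;
    let parsed: ChatCompletion = serde_json::from_str(snapshot).unwrap();
    assert_eq!(parsed.choices[0].finish_reason, "length");
    assert!(parsed.choices[0].message.tool_calls.is_none());
}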
@@ -468,3 +468,42 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
         == '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days":3}}<|eot_id|>'
     )
     assert last_response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_tool_reply_response(
+    flash_llama_grammar_tools, response_snapshot
+):
+    responses = await flash_llama_grammar_tools.chat(
+        max_tokens=100,
+        seed=43,
+        messages=[
+            {"role": "user", "content": "What's the weather like in Paris today?"},
+            {
+                "content": "",
+                "role": "assistant",
+                "tool_calls": [
+                    {
+                        "id": "0",
+                        "function": {
+                            "arguments": '{"longitude": 2.2945, "latitude": 48.8567}',
+                            "name": "get_weather",
+                            "description": None,
+                        },
+                        "type": "function",
+                    }
+                ],
+            },
+            {"role": "tool", "tool_call_id": "0", "content": "6.7"},
+        ],
+        stream=False,
+    )
+
+    assert responses.choices[0].message.tool_calls is None
+    assert (
+        responses.choices[0].message.content
+        == "I'm an AI and do not have access to real-time data. However, based on location information (Paris) I can provide general information. \n\nThe temperature in Paris varies widely throughout the year. In the summer (June to August), the average high temperature is around 23°C (73°F), while in the winter (December to February), the average low temperature is around -1°C (30°F). \n\nTo get the current weather in Paris, I recommend checking a weather website or"
+    )
+
+    assert responses == response_snapshot
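The new test drives the complete tool-calling loop through the Python client fixture: a user question, a prior assistant turn that carries only a `get_weather` tool call, and a `tool` message returning the result `"6.7"` under the matching `tool_call_id`. It then asserts that the final reply is plain text (`tool_calls is None`) and matches the snapshot above. For reference, a sketch of the same conversation as a raw JSON request body, as one might POST to an OpenAI-compatible chat endpoint; field names follow the OpenAI chat schema and the payload is illustrative, not TGI's internal representation:

// Sketch: the test's three-message conversation as a JSON request body.
use serde_json::json;

fn main() {
    let body = json!({
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "max_tokens": 100,
        "seed": 43,
        "messages": [
            // 1. The user asks a question a tool can answer.
            {"role": "user", "content": "What's the weather like in Paris today?"},
            // 2. The assistant's earlier turn: empty content, one tool call.
            {"content": "", "role": "assistant", "tool_calls": [{
                "id": "0",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": "{\"longitude\": 2.2945, \"latitude\": 48.8567}",
                    "description": null
                }
            }]},
            // 3. The tool's result, keyed back to the call by tool_call_id.
            {"role": "tool", "tool_call_id": "0", "content": "6.7"}
        ]
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}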
@@ -73,10 +73,8 @@ impl ChatTemplate {
             // if the `tools` variable is used in the template, we just append the tool_prompt
             format!("\n---\n{}", tool_prompt)
         };
-        if let Some(last_message) = messages.last_mut() {
-            if let Some(content) = last_message.content.as_mut() {
-                content.push(MessageChunk::Text { text })
-            }
-        }
+        if let Some(content) = messages.last_mut().and_then(|msg| msg.content.as_mut()) {
+            content.push(MessageChunk::Text { text });
+        }
         Some(tools)
     }
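This hunk is a behavior-preserving cleanup: the two nested `if let`s that appended the tool prompt to the last message's content collapse into a single `Option::and_then` chain. A self-contained sketch of the equivalence, using stand-in types rather than the real `Message`/`MessageChunk`:

// Sketch: Option::and_then flattens two nested `if let`s without changing
// behavior. `Msg` is a stand-in type, not the router's real Message.
struct Msg {
    content: Option<Vec<String>>,
}

fn append_nested(messages: &mut Vec<Msg>, text: String) {
    // Before: two nested `if let`s.
    if let Some(last_message) = messages.last_mut() {
        if let Some(content) = last_message.content.as_mut() {
            content.push(text);
        }
    }
}

fn append_flattened(messages: &mut Vec<Msg>, text: String) {
    // After: `and_then` chains the two Option lookups into one binding.
    if let Some(content) = messages.last_mut().and_then(|msg| msg.content.as_mut()) {
        content.push(text);
    }
}

fn main() {
    let mut a = vec![Msg { content: Some(vec![]) }];
    let mut b = vec![Msg { content: Some(vec![]) }];
    append_nested(&mut a, "tool prompt".into());
    append_flattened(&mut b, "tool prompt".into());
    assert_eq!(a[0].content, b[0].content); // identical result
}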
@@ -1222,22 +1222,24 @@ pub struct TextMessage {
 
 impl From<Message> for TextMessage {
     fn from(value: Message) -> Self {
+        let content = value
+            .tool_calls
+            .map(|calls| serde_json::to_string(&calls).unwrap_or_default())
+            .map(MessageContent::SingleText)
+            .or(value.content)
+            .unwrap_or_else(|| MessageContent::SingleText(String::new()));
         TextMessage {
             role: value.role,
-            content: match value.content {
-                // If content is Some(MessageContent), handle it accordingly
-                Some(MessageContent::SingleText(text)) => text,
-                Some(MessageContent::MultipleChunks(chunks)) => {
-                    chunks.into_iter()
-                        .map(|chunk| match chunk {
-                            MessageChunk::Text { text } => text,
-                            MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url),
-                        })
-                        .collect::<Vec<_>>()
-                        .join("")
-                }
-                // If content is None, use an empty string or a default message
-                None => String::new(), // or you could use "No content" or another placeholder
-            },
+            content: match content {
+                MessageContent::SingleText(text) => text,
+                MessageContent::MultipleChunks(chunks) => chunks
+                    .into_iter()
+                    .map(|chunk| match chunk {
+                        MessageChunk::Text { text } => text,
+                        MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url),
+                    })
+                    .collect::<Vec<_>>()
+                    .join(""),
+            },
         }
     }
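This hunk is the substantive part of the commit: when a `Message` carries `tool_calls`, they are now serialized to a JSON string and used as the message's text content, taking precedence over `value.content`, so chat templates that only understand text can still render prior assistant tool calls; a message with neither yields an empty string. A self-contained sketch of the rule with simplified stand-in types (the real `Message`/`TextMessage` carry richer content variants):

// Sketch of the new conversion rule on simplified types: tool_calls, when
// present, win over plain content and are serialized to a JSON string.
use serde::Serialize;

#[derive(Serialize)]
struct ToolCall {
    id: String,
    name: String,
    arguments: String,
}

struct Message {
    role: String,
    content: Option<String>,
    tool_calls: Option<Vec<ToolCall>>,
}

struct TextMessage {
    role: String,
    content: String,
}

impl From<Message> for TextMessage {
    fn from(value: Message) -> Self {
        let content = value
            .tool_calls
            // Serialize the calls; a serialization failure degrades to "".
            .map(|calls| serde_json::to_string(&calls).unwrap_or_default())
            // Prefer serialized tool calls, then plain content, then "".
            .or(value.content)
            .unwrap_or_default();
        TextMessage { role: value.role, content }
    }
}

fn main() {
    let msg = Message {
        role: "assistant".into(),
        content: Some(String::new()),
        tool_calls: Some(vec![ToolCall {
            id: "0".into(),
            name: "get_weather".into(),
            arguments: r#"{"longitude": 2.2945, "latitude": 48.8567}"#.into(),
        }]),
    };
    let text: TextMessage = msg.into();
    // The tool call now rides along as serialized JSON text.
    assert!(text.content.contains("get_weather"));
}

Note the `unwrap_or_default()` on the serialization result, mirroring the diff: a failed `serde_json::to_string` produces an empty content string rather than an error.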