mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-26 12:32:10 +00:00
fix: only send function name on first stream event
This commit is contained in:
parent
68aa6b1af0
commit
538456ba68
@ -20,7 +20,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1739799458,
|
||||
"created": 1739910558,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
|
@ -20,7 +20,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1739797595,
|
||||
"created": 1739910826,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
|
@ -21,7 +21,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1739456930,
|
||||
"created": 1739910816,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
|
@ -21,7 +21,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1739367874,
|
||||
"created": 1739910803,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
|
@ -101,7 +101,9 @@ async def test_openai_llama_tools(openai_llama_tools, response_snapshot):
|
||||
|
||||
tool_call_string = ""
|
||||
for chunk in chat_completion:
|
||||
tool_call_string += chunk.choices[0].delta.tool_calls[0].function.arguments
|
||||
function_call = chunk.choices[0].delta.tool_calls[0].function
|
||||
if function_call:
|
||||
tool_call_string += function_call.arguments
|
||||
last_chunk = chunk.to_dict()
|
||||
|
||||
assert (
|
||||
|
@ -216,7 +216,7 @@ async def test_flash_llama_grammar_tools_stream(
|
||||
assert response.choices[0].delta.content is None
|
||||
|
||||
assert tool_calls_generated == '{ "location": "Paris, France", "format": "celsius"}'
|
||||
assert count == 16
|
||||
assert count == 17
|
||||
assert last_response == response_snapshot
|
||||
|
||||
|
||||
@ -360,7 +360,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
|
||||
)
|
||||
last_response = response
|
||||
|
||||
assert count == 23
|
||||
assert count == 24
|
||||
assert (
|
||||
tool_calls_generated
|
||||
== '{ "location": "San Francisco, CA", "format": "fahrenheit", "num_days":3}'
|
||||
@ -458,7 +458,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
|
||||
tool_calls_generated += tool_call["function"]["arguments"]
|
||||
last_response = response
|
||||
|
||||
assert count == 25
|
||||
assert count == 26
|
||||
assert (
|
||||
tool_calls_generated
|
||||
== '{ "location": "San Francisco, CA", "format": "celsius", "num_days": 3}'
|
||||
|
@ -1305,6 +1305,40 @@ pub(crate) async fn chat_completions(
|
||||
state = StreamState::Content {
|
||||
skip_close_quote: false,
|
||||
};
|
||||
let event = Event::default();
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
let tool_delta_start = ChatCompletionDelta::Tool(ToolCallDelta {
|
||||
role: "assistant".to_string(),
|
||||
tool_calls: vec![DeltaToolCall {
|
||||
index: 0,
|
||||
id: String::new(),
|
||||
r#type: "function".to_string(),
|
||||
function: Function {
|
||||
name: Some(global_function_name.clone()),
|
||||
arguments: "".to_string(),
|
||||
},
|
||||
}],
|
||||
});
|
||||
let chat_complete =
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk{
|
||||
id: String::new(),
|
||||
created: current_time,
|
||||
model: model_id.clone(),
|
||||
system_fingerprint: system_fingerprint.clone(),
|
||||
choices: vec![ChatCompletionChoice {
|
||||
index: 0,
|
||||
delta: tool_delta_start,
|
||||
logprobs: None,
|
||||
finish_reason: None,
|
||||
}],
|
||||
usage: None,
|
||||
});
|
||||
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||
InferError::StreamSerializationError(e.to_string()).into()
|
||||
}));
|
||||
buffer.drain(1..); // only keep the first token (opening '{')
|
||||
buffer[0].token.text = buffer[0].token.text.chars().take(1).collect();
|
||||
}
|
||||
@ -1341,7 +1375,7 @@ pub(crate) async fn chat_completions(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
Some(global_function_name.clone()),
|
||||
None,
|
||||
));
|
||||
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||
InferError::StreamSerializationError(e.to_string()).into()
|
||||
@ -1370,7 +1404,7 @@ pub(crate) async fn chat_completions(
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
Some(global_function_name.clone()),
|
||||
None,
|
||||
);
|
||||
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
@ -1394,7 +1428,7 @@ pub(crate) async fn chat_completions(
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
Some(global_function_name.clone()),
|
||||
None,
|
||||
);
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
} else {
|
||||
@ -1407,7 +1441,7 @@ pub(crate) async fn chat_completions(
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
Some(global_function_name.clone()),
|
||||
None,
|
||||
);
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user