fix: only send function name on first stream event

This commit is contained in:
drbh 2025-02-18 21:13:03 +00:00
parent 68aa6b1af0
commit 538456ba68
7 changed files with 48 additions and 12 deletions

View File

@ -20,7 +20,7 @@
"logprobs": null
}
],
"created": 1739799458,
"created": 1739910558,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",

View File

@ -20,7 +20,7 @@
"logprobs": null
}
],
"created": 1739797595,
"created": 1739910826,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",

View File

@ -21,7 +21,7 @@
"logprobs": null
}
],
"created": 1739456930,
"created": 1739910816,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",

View File

@ -21,7 +21,7 @@
"logprobs": null
}
],
"created": 1739367874,
"created": 1739910803,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",

View File

@ -101,7 +101,9 @@ async def test_openai_llama_tools(openai_llama_tools, response_snapshot):
tool_call_string = ""
for chunk in chat_completion:
tool_call_string += chunk.choices[0].delta.tool_calls[0].function.arguments
function_call = chunk.choices[0].delta.tool_calls[0].function
if function_call:
tool_call_string += function_call.arguments
last_chunk = chunk.to_dict()
assert (

View File

@ -216,7 +216,7 @@ async def test_flash_llama_grammar_tools_stream(
assert response.choices[0].delta.content is None
assert tool_calls_generated == '{ "location": "Paris, France", "format": "celsius"}'
assert count == 16
assert count == 17
assert last_response == response_snapshot
@ -360,7 +360,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
)
last_response = response
assert count == 23
assert count == 24
assert (
tool_calls_generated
== '{ "location": "San Francisco, CA", "format": "fahrenheit", "num_days":3}'
@ -458,7 +458,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
tool_calls_generated += tool_call["function"]["arguments"]
last_response = response
assert count == 25
assert count == 26
assert (
tool_calls_generated
== '{ "location": "San Francisco, CA", "format": "celsius", "num_days": 3}'

View File

@ -1305,6 +1305,40 @@ pub(crate) async fn chat_completions(
state = StreamState::Content {
skip_close_quote: false,
};
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let tool_delta_start = ChatCompletionDelta::Tool(ToolCallDelta {
role: "assistant".to_string(),
tool_calls: vec![DeltaToolCall {
index: 0,
id: String::new(),
r#type: "function".to_string(),
function: Function {
name: Some(global_function_name.clone()),
arguments: "".to_string(),
},
}],
});
let chat_complete =
CompletionType::ChatCompletionChunk(ChatCompletionChunk{
id: String::new(),
created: current_time,
model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(),
choices: vec![ChatCompletionChoice {
index: 0,
delta: tool_delta_start,
logprobs: None,
finish_reason: None,
}],
usage: None,
});
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
}));
buffer.drain(1..); // only keep the first token (opening '{')
buffer[0].token.text = buffer[0].token.text.chars().take(1).collect();
}
@ -1341,7 +1375,7 @@ pub(crate) async fn chat_completions(
None,
None,
None,
Some(global_function_name.clone()),
None,
));
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
@ -1370,7 +1404,7 @@ pub(crate) async fn chat_completions(
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
Some(global_function_name.clone()),
None,
);
yield Ok::<Event, Infallible>(event);
@ -1394,7 +1428,7 @@ pub(crate) async fn chat_completions(
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
Some(global_function_name.clone()),
None,
);
yield Ok::<Event, Infallible>(event);
} else {
@ -1407,7 +1441,7 @@ pub(crate) async fn chat_completions(
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
Some(global_function_name.clone()),
None,
);
yield Ok::<Event, Infallible>(event);
}