mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-12 12:22:07 +00:00
fix: only send function name on first stream event
This commit is contained in:
parent
68aa6b1af0
commit
538456ba68
@ -20,7 +20,7 @@
|
|||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1739799458,
|
"created": 1739910558,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
|
@ -20,7 +20,7 @@
|
|||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1739797595,
|
"created": 1739910826,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
|
@ -21,7 +21,7 @@
|
|||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1739456930,
|
"created": 1739910816,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
|
@ -21,7 +21,7 @@
|
|||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1739367874,
|
"created": 1739910803,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
|
@ -101,7 +101,9 @@ async def test_openai_llama_tools(openai_llama_tools, response_snapshot):
|
|||||||
|
|
||||||
tool_call_string = ""
|
tool_call_string = ""
|
||||||
for chunk in chat_completion:
|
for chunk in chat_completion:
|
||||||
tool_call_string += chunk.choices[0].delta.tool_calls[0].function.arguments
|
function_call = chunk.choices[0].delta.tool_calls[0].function
|
||||||
|
if function_call:
|
||||||
|
tool_call_string += function_call.arguments
|
||||||
last_chunk = chunk.to_dict()
|
last_chunk = chunk.to_dict()
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
|
@ -216,7 +216,7 @@ async def test_flash_llama_grammar_tools_stream(
|
|||||||
assert response.choices[0].delta.content is None
|
assert response.choices[0].delta.content is None
|
||||||
|
|
||||||
assert tool_calls_generated == '{ "location": "Paris, France", "format": "celsius"}'
|
assert tool_calls_generated == '{ "location": "Paris, France", "format": "celsius"}'
|
||||||
assert count == 16
|
assert count == 17
|
||||||
assert last_response == response_snapshot
|
assert last_response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@ -360,7 +360,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
|
|||||||
)
|
)
|
||||||
last_response = response
|
last_response = response
|
||||||
|
|
||||||
assert count == 23
|
assert count == 24
|
||||||
assert (
|
assert (
|
||||||
tool_calls_generated
|
tool_calls_generated
|
||||||
== '{ "location": "San Francisco, CA", "format": "fahrenheit", "num_days":3}'
|
== '{ "location": "San Francisco, CA", "format": "fahrenheit", "num_days":3}'
|
||||||
@ -458,7 +458,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
|
|||||||
tool_calls_generated += tool_call["function"]["arguments"]
|
tool_calls_generated += tool_call["function"]["arguments"]
|
||||||
last_response = response
|
last_response = response
|
||||||
|
|
||||||
assert count == 25
|
assert count == 26
|
||||||
assert (
|
assert (
|
||||||
tool_calls_generated
|
tool_calls_generated
|
||||||
== '{ "location": "San Francisco, CA", "format": "celsius", "num_days": 3}'
|
== '{ "location": "San Francisco, CA", "format": "celsius", "num_days": 3}'
|
||||||
|
@ -1305,6 +1305,40 @@ pub(crate) async fn chat_completions(
|
|||||||
state = StreamState::Content {
|
state = StreamState::Content {
|
||||||
skip_close_quote: false,
|
skip_close_quote: false,
|
||||||
};
|
};
|
||||||
|
let event = Event::default();
|
||||||
|
let current_time = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||||
|
.as_secs();
|
||||||
|
let tool_delta_start = ChatCompletionDelta::Tool(ToolCallDelta {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
tool_calls: vec![DeltaToolCall {
|
||||||
|
index: 0,
|
||||||
|
id: String::new(),
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: Function {
|
||||||
|
name: Some(global_function_name.clone()),
|
||||||
|
arguments: "".to_string(),
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
let chat_complete =
|
||||||
|
CompletionType::ChatCompletionChunk(ChatCompletionChunk{
|
||||||
|
id: String::new(),
|
||||||
|
created: current_time,
|
||||||
|
model: model_id.clone(),
|
||||||
|
system_fingerprint: system_fingerprint.clone(),
|
||||||
|
choices: vec![ChatCompletionChoice {
|
||||||
|
index: 0,
|
||||||
|
delta: tool_delta_start,
|
||||||
|
logprobs: None,
|
||||||
|
finish_reason: None,
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
});
|
||||||
|
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||||
|
InferError::StreamSerializationError(e.to_string()).into()
|
||||||
|
}));
|
||||||
buffer.drain(1..); // only keep the first token (opening '{')
|
buffer.drain(1..); // only keep the first token (opening '{')
|
||||||
buffer[0].token.text = buffer[0].token.text.chars().take(1).collect();
|
buffer[0].token.text = buffer[0].token.text.chars().take(1).collect();
|
||||||
}
|
}
|
||||||
@ -1341,7 +1375,7 @@ pub(crate) async fn chat_completions(
|
|||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
Some(global_function_name.clone()),
|
None,
|
||||||
));
|
));
|
||||||
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
|
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||||
InferError::StreamSerializationError(e.to_string()).into()
|
InferError::StreamSerializationError(e.to_string()).into()
|
||||||
@ -1370,7 +1404,7 @@ pub(crate) async fn chat_completions(
|
|||||||
response_as_tool,
|
response_as_tool,
|
||||||
system_fingerprint.clone(),
|
system_fingerprint.clone(),
|
||||||
model_id.clone(),
|
model_id.clone(),
|
||||||
Some(global_function_name.clone()),
|
None,
|
||||||
);
|
);
|
||||||
|
|
||||||
yield Ok::<Event, Infallible>(event);
|
yield Ok::<Event, Infallible>(event);
|
||||||
@ -1394,7 +1428,7 @@ pub(crate) async fn chat_completions(
|
|||||||
response_as_tool,
|
response_as_tool,
|
||||||
system_fingerprint.clone(),
|
system_fingerprint.clone(),
|
||||||
model_id.clone(),
|
model_id.clone(),
|
||||||
Some(global_function_name.clone()),
|
None,
|
||||||
);
|
);
|
||||||
yield Ok::<Event, Infallible>(event);
|
yield Ok::<Event, Infallible>(event);
|
||||||
} else {
|
} else {
|
||||||
@ -1407,7 +1441,7 @@ pub(crate) async fn chat_completions(
|
|||||||
response_as_tool,
|
response_as_tool,
|
||||||
system_fingerprint.clone(),
|
system_fingerprint.clone(),
|
||||||
model_id.clone(),
|
model_id.clone(),
|
||||||
Some(global_function_name.clone()),
|
None,
|
||||||
);
|
);
|
||||||
yield Ok::<Event, Infallible>(event);
|
yield Ok::<Event, Infallible>(event);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user