mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-27 21:12:07 +00:00

feat: refactor and simplify chat stream more, bump tests and support stream_options

This commit is contained in:
parent c4cb54c23e
commit a5ddc9db52
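The headline change is `stream_options` support on the chat completions route: when a request sets `stream_options.include_usage`, the stream now ends with an extra chunk whose `choices` array is empty and whose `usage` field carries the token counts (see the bumped snapshots and tests below). A minimal consumer sketch, assuming an OpenAI-compatible client pointed at a local TGI endpoint; the base URL, API key, and prompt are placeholders:

```python
from openai import OpenAI

# Placeholder endpoint; TGI exposes an OpenAI-compatible API under /v1.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

stream = client.chat.completions.create(
    model="tgi",  # "tgi" (or omitting the model) falls back to the server's own model id
    messages=[{"role": "user", "content": "What is Deep Learning?"}],
    max_tokens=10,
    stream=True,
    stream_options={"include_usage": True},
)

text = ""
for chunk in stream:
    if chunk.choices:
        # Ordinary delta chunks: content, with "usage": null.
        text += chunk.choices[0].delta.content or ""
    elif chunk.usage is not None:
        # Final empty-choices chunk carrying the usage payload.
        print(chunk.usage.prompt_tokens, chunk.usage.completion_tokens)
print(text)
```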
@@ -269,6 +269,8 @@ class ResponseComparator(JSONSnapshotExtension):
         def eq_chat_complete_chunk(
             response: ChatCompletionChunk, other: ChatCompletionChunk
         ) -> bool:
+            if len(response.choices) == 0:
+                return len(other.choices) == 0
             return response.choices[0].delta.content == other.choices[0].delta.content

         def eq_response(response: Response, other: Response) -> bool:
@@ -12,11 +12,11 @@
       "logprobs": null
     }
   ],
-  "created": 1726656043,
+  "created": 1740516693,
   "id": "",
-  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
-  "system_fingerprint": "2.2.1-dev0-native",
+  "system_fingerprint": "3.1.1-dev0-native",
   "usage": null
 },
 {
@@ -32,11 +32,11 @@
       "logprobs": null
     }
   ],
-  "created": 1726656043,
+  "created": 1740516693,
   "id": "",
-  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
-  "system_fingerprint": "2.2.1-dev0-native",
+  "system_fingerprint": "3.1.1-dev0-native",
   "usage": null
 },
 {
@@ -52,11 +52,11 @@
       "logprobs": null
     }
   ],
-  "created": 1726656043,
+  "created": 1740516693,
   "id": "",
-  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
-  "system_fingerprint": "2.2.1-dev0-native",
+  "system_fingerprint": "3.1.1-dev0-native",
   "usage": null
 },
 {
@@ -72,11 +72,11 @@
       "logprobs": null
     }
   ],
-  "created": 1726656043,
+  "created": 1740516694,
   "id": "",
-  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
-  "system_fingerprint": "2.2.1-dev0-native",
+  "system_fingerprint": "3.1.1-dev0-native",
   "usage": null
 },
 {
@@ -92,11 +92,11 @@
       "logprobs": null
     }
   ],
-  "created": 1726656043,
+  "created": 1740516694,
   "id": "",
-  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
-  "system_fingerprint": "2.2.1-dev0-native",
+  "system_fingerprint": "3.1.1-dev0-native",
   "usage": null
 },
 {
@@ -112,11 +112,11 @@
       "logprobs": null
     }
   ],
-  "created": 1726656043,
+  "created": 1740516694,
   "id": "",
-  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
-  "system_fingerprint": "2.2.1-dev0-native",
+  "system_fingerprint": "3.1.1-dev0-native",
   "usage": null
 },
 {
@@ -132,11 +132,11 @@
       "logprobs": null
     }
   ],
-  "created": 1726656044,
+  "created": 1740516694,
   "id": "",
-  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
-  "system_fingerprint": "2.2.1-dev0-native",
+  "system_fingerprint": "3.1.1-dev0-native",
   "usage": null
 },
 {
@@ -152,11 +152,11 @@
       "logprobs": null
     }
   ],
-  "created": 1726656044,
+  "created": 1740516694,
   "id": "",
-  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
-  "system_fingerprint": "2.2.1-dev0-native",
+  "system_fingerprint": "3.1.1-dev0-native",
   "usage": null
 },
 {
@@ -172,11 +172,11 @@
       "logprobs": null
     }
   ],
-  "created": 1726656044,
+  "created": 1740516694,
   "id": "",
-  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
-  "system_fingerprint": "2.2.1-dev0-native",
+  "system_fingerprint": "3.1.1-dev0-native",
   "usage": null
 },
 {
@@ -192,11 +192,20 @@
       "logprobs": null
     }
   ],
-  "created": 1726656044,
+  "created": 1740516694,
   "id": "",
-  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
-  "system_fingerprint": "2.2.1-dev0-native",
+  "system_fingerprint": "3.1.1-dev0-native",
+  "usage": null
+},
+{
+  "choices": [],
+  "created": 1740516694,
+  "id": "",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
+  "object": "chat.completion.chunk",
+  "system_fingerprint": "3.1.1-dev0-native",
   "usage": {
     "completion_tokens": 10,
     "prompt_tokens": 40,
@@ -5,7 +5,7 @@
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from this information, and I recommend checking the forecast on a reliable weather website for the most up-to-date information.",
+        "content": "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from the provided information. For up-to-date information, I suggest checking a reliable weather website or app for the latest conditions and forecast.",
         "name": null,
         "role": "assistant",
         "tool_calls": null
@@ -13,14 +13,14 @@
       "usage": null
     }
   ],
-  "created": 1739932427,
+  "created": 1740516945,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion",
   "system_fingerprint": "3.1.1-dev0-native",
   "usage": {
-    "completion_tokens": 79,
-    "prompt_tokens": 103,
-    "total_tokens": 182
+    "completion_tokens": 83,
+    "prompt_tokens": 109,
+    "total_tokens": 192
   }
 }
@@ -91,7 +91,7 @@ async def test_flash_llama_completion_stream_usage(
             index = c["choices"][0]["index"]
             assert index == 0
             string += c["choices"][0]["delta"]["content"]
-
+        elif len(c["choices"]) == 0:
             has_usage = c["usage"] is not None
             assert not had_usage
             if has_usage:
@@ -142,7 +142,7 @@ async def test_flash_llama_completion_stream_usage(
             index = c["choices"][0]["index"]
             assert index == 0
             string += c["choices"][0]["delta"]["content"]
-
+        elif len(c["choices"]) == 0:
            has_usage = c["usage"] is not None
            assert not had_usage
            if has_usage:
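The updated stream-usage tests branch on chunk shape: exactly one choice means a content delta, zero choices means the usage-bearing chunk. A hedged sketch of the same consumption logic outside the test harness, assuming `lines` yields raw SSE payloads (`data: {...}` strings) from a stream started with `include_usage`:

```python
import json

def consume_sse(lines):
    """Collect text deltas and the single usage payload from a chat stream."""
    text, usage = "", None
    for line in lines:
        payload = line.removeprefix("data:").strip()
        if not payload or payload == "[DONE]":
            continue
        c = json.loads(payload)
        if len(c["choices"]) == 1:
            text += c["choices"][0]["delta"].get("content") or ""
        elif len(c["choices"]) == 0 and c["usage"] is not None:
            # Usage should arrive exactly once, on a chunk with no choices.
            assert usage is None
            usage = c["usage"]
    return text, usage
```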
@@ -497,7 +497,7 @@ async def test_flash_llama_tool_reply_response(
     assert responses.choices[0].message.tool_calls is None
     assert (
         responses.choices[0].message.content
-        == "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from this information, and I recommend checking the forecast on a reliable weather website for the most up-to-date information."
+        == "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from the provided information. For up-to-date information, I suggest checking a reliable weather website or app for the latest conditions and forecast."
     )

     assert responses == response_snapshot
@@ -1152,10 +1152,10 @@ fn complete_json(partial: &str) -> (String, bool) {
 }

 // Generic function that parses any partial structure into a Map
-fn parse_generic_structure(partial: &str) -> Result<Map<String, Value>, String> {
-    let (completed, _) = complete_json(partial);
+fn parse_generic_structure(partial: &str) -> Result<(Map<String, Value>, bool), String> {
+    let (completed, quote_open) = complete_json(partial);
     match serde_json::from_str::<Value>(&completed) {
-        Ok(Value::Object(obj)) => Ok(obj),
+        Ok(Value::Object(obj)) => Ok((obj, quote_open)),
         _ => Err("Failed to parse as object".to_string()),
     }
 }
@@ -1335,24 +1335,32 @@ pub(crate) async fn chat_completions(
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
     metrics::counter!("tgi_request_count").increment(1);
-    let ChatRequest {
-        model,
-        stream,
-        logprobs,
-        // TODO: add back and maybe consolidate the other PR
-        // stream_options,
-        ..
-    } = chat.clone();
-    let (generate_request, using_tools) = chat.try_into_generate(&infer)?;
-    let logprobs = logprobs.unwrap_or_default();

-    // extract model id from request if specified
-    let model_id = match model.as_deref() {
-        Some("tgi") | None => info.model_id.clone(),
-        Some(m_id) => m_id.to_string(),
-    };
+    // Extract needed fields
+    let model = chat.model.clone();
+    let logprobs = chat.logprobs.unwrap_or_default();
+    let stream = chat.stream;
+    let stream_options = chat.stream_options.clone();
+
+    // Process request (this consumes chat)
+    let (generate_request, using_tools) = chat.try_into_generate(&infer)?;
+
+    // Determine model ID
+    let model_id = model
+        .as_deref()
+        .filter(|&m| m != "tgi")
+        .unwrap_or(&info.model_id)
+        .to_string();
     let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
+
+    // Helper function to get current timestamp
+    let get_timestamp = || {
+        std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .as_secs()
+    };

     if stream {
         let (headers, response_stream) =
             generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
@@ -1364,103 +1372,198 @@ pub(crate) async fn chat_completions(
             let mut no_tool_chosen = false;
             let mut first_quote_removed = false;
-            while let Some(result) = response_stream.next().await {
-                match result {
-                    Ok(stream_token) => {
-                        let token_text = stream_token.token.text.clone();
-                        json_buffer.push_str(&token_text);
-                        if !name_found {
-                            // since we know tools is attempting to follow a grammar we can attempt to
-                            // partially parse the json_buffer to see if we can extract the function name
-                            if let Ok(function) = parse_partial_json(&json_buffer) {
-                                let name = function.get("_name").and_then(|n| n.as_str()).unwrap_or("no_tool");
-                                name_found = true;
-                                if name == "no_tool" {
-                                    no_tool_chosen = true;
-                                    json_buffer.clear();
-                                    json_buffer.push('{');
-                                } else {
-                                    let tool_name_event = create_event(&token_text, &model_id, &system_fingerprint, Some(name), false, None);
-                                    yield Ok::<Event, Infallible>(tool_name_event);
-                                    let tool_open_arguments_event = create_event("{", &model_id, &system_fingerprint, None, true, None);
-                                    yield Ok::<Event, Infallible>(tool_open_arguments_event);
-                                    // clear the buffer as we know that the buffer is only the function
-                                    // ie: ` {"function": {"_name": "get_current_weather",` -> `{"`
-                                    // we need to keep the `{` to open the arguments and allow the parser to continue
-                                    json_buffer.clear();
-                                    json_buffer.push('{');
-                                }
-                            }
-                        } else {
-                            // Process JSON buffer and handle token text
-                            let last_is_brace = json_buffer.ends_with('}');
-                            let edited_buffer = if last_is_brace { &json_buffer[..json_buffer.len() - 1] } else { &json_buffer };
-                            let mut token_text = stream_token.token.text.clone();
-                            let is_json_complete = serde_json::from_str::<Value>(edited_buffer).is_ok();
-                            // Handle tool usage cases
-                            if using_tools {
-                                if no_tool_chosen {
-                                    // Tool without selection ("content" flow)
-                                    if let Ok(function) = parse_generic_structure(edited_buffer) {
-                                        if function.get("content").and_then(|c| c.as_str()).is_some() {
-                                            // Handle quotation marks
-                                            if !first_quote_removed {
-                                                first_quote_removed = true;
-                                                if token_text == "\"" || token_text == " \"" { continue; }
-                                                token_text = token_text.replace("\"", "");
-                                            } else if token_text.ends_with('"') {
-                                                token_text = token_text[..token_text.len() - 1].to_string();
-                                            }
-
-                                            if is_json_complete { break; }
-                                            yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, false, None));
-                                            continue;
-                                        }
-                                    }
-                                    continue;
-                                } else {
-                                    // Tool with selection
-                                    if is_json_complete {
-                                        // Final token with possible brace removal
-                                        if last_is_brace { token_text = token_text[..token_text.len() - 1].to_string(); }
-                                        yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, true, None));
-                                        break;
-                                    }
-                                    yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, true, None));
-                                    continue;
-                                }
-                            } else {
-                                // Default case: standard chat completion
-                                if let Some(details) = stream_token.details.as_ref() {
-                                    // Handle final token and only send text if ended because of length
-                                    let text = if details.finish_reason == FinishReason::Length { &token_text } else { "" };
-                                    yield Ok::<Event, Infallible>(create_event(text, &model_id, &system_fingerprint, None, false, Some(details.finish_reason.format(true))));
-                                    break;
-                                }
-                                yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, false, None));
-                            }
-                        }
-                    }
-                    Err(err) => yield Ok(err.into_openai_event()),
-                }
-            }
+
+            // Process stream tokens
+            while let Some(Ok(stream_token)) = response_stream.next().await {
+                let token_text = stream_token.token.text.clone();
+                let mut events = Vec::new();
+                let mut should_break = false;
+
+                // Get usage information
+                let usage = stream_token.details.as_ref().map(|d| Usage {
+                    completion_tokens: d.generated_tokens,
+                    prompt_tokens: d.input_length,
+                    total_tokens: d.input_length + d.generated_tokens,
+                });
+
+                json_buffer.push_str(&token_text);
+
+                // Phase 1: Function name discovery
+                if !name_found {
+                    if let Ok(function) = parse_partial_json(&json_buffer) {
+                        name_found = true;
+
+                        let name = function
+                            .get("_name")
+                            .and_then(|n| n.as_str())
+                            .unwrap_or_default();
+                        if name == "no_tool" {
+                            no_tool_chosen = true;
+                        } else {
+                            events.push(create_event(
+                                &token_text,
+                                &model_id,
+                                &system_fingerprint,
+                                Some(name),
+                                false,
+                                None,
+                            ));
+                            events.push(create_event(
+                                "{",
+                                &model_id,
+                                &system_fingerprint,
+                                None,
+                                true,
+                                None,
+                            ));
+                        }
+
+                        // Reset buffer for arguments
+                        json_buffer.clear();
+                        json_buffer.push('{');
+                    }
+
+                    for event in events {
+                        yield Ok::<Event, Infallible>(event);
+                    }
+                    continue;
+                }
+
+                // Phase 2: Content processing
+                let is_complete_json = json_buffer.ends_with('}')
+                    && serde_json::from_str::<Value>(&json_buffer[..json_buffer.len() - 1]).is_ok();
+                let mut edited_token = token_text;
+
+                // Handle different flows based on context
+                if using_tools {
+                    if no_tool_chosen && !is_complete_json {
+                        // Content-only flow
+                        if let Ok((function, quote_open)) = parse_generic_structure(&json_buffer) {
+                            if let Some(_content) = function.get("content").and_then(|c| c.as_str()) {
+                                let cleaned_token = if !first_quote_removed {
+                                    // trim start until the first quote
+                                    first_quote_removed = true;
+                                    edited_token
+                                        .trim_start()
+                                        .strip_prefix('"')
+                                        .unwrap_or(&edited_token)
+                                        .to_string()
+                                } else if !quote_open {
+                                    should_break = true;
+                                    // trim end until the last quote
+                                    edited_token
+                                        .trim_end()
+                                        .strip_suffix('"')
+                                        .unwrap_or(&edited_token)
+                                        .to_string()
+                                } else {
+                                    edited_token.to_string()
+                                };
+
+                                if !cleaned_token.is_empty() {
+                                    events.push(create_event(
+                                        &cleaned_token,
+                                        &model_id,
+                                        &system_fingerprint,
+                                        None,
+                                        false,
+                                        None,
+                                    ));
+                                }
+                            }
+                        }
+                    } else {
+                        // Tool with arguments flow
+                        if is_complete_json {
+                            edited_token.truncate(edited_token.len() - 1);
+                            should_break = true;
+                        }
+                        events.push(create_event(
+                            &edited_token,
+                            &model_id,
+                            &system_fingerprint,
+                            None,
+                            true,
+                            None,
+                        ));
+                    }
+                } else {
+                    // Standard chat completion flow
+                    if let Some(details) = stream_token.details.as_ref() {
+                        let finish_reason = details.finish_reason.format(true);
+                        let text = if details.finish_reason == FinishReason::Length {
+                            &edited_token
+                        } else {
+                            ""
+                        };
+                        events.push(create_event(
+                            text,
+                            &model_id,
+                            &system_fingerprint,
+                            None,
+                            false,
+                            Some(finish_reason),
+                        ));
+                        should_break = true;
+                    } else {
+                        events.push(create_event(
+                            &edited_token,
+                            &model_id,
+                            &system_fingerprint,
+                            None,
+                            false,
+                            None,
+                        ));
+                    }
+                }
+
+                // Emit all collected events
+                for event in events {
+                    yield Ok::<Event, Infallible>(event);
+                }
+
+                // Emit usage data when requested
+                if let (Some(usage_data), true) = (
+                    usage,
+                    stream_options.as_ref().is_some_and(|o| o.include_usage),
+                ) {
+                    let current_time = get_timestamp();
+
+                    let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk {
+                        id: String::new(),
+                        created: current_time,
+                        model: model_id.clone(),
+                        system_fingerprint: system_fingerprint.clone(),
+                        choices: vec![],
+                        usage: Some(usage_data),
+                    });
+
+                    yield Ok(Event::default()
+                        .json_data(chat_complete)
+                        .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into()));
+                }
+
+                if should_break {
+                    break;
+                }
+            }
+
+            // Handle any errors in the stream
+            if let Some(Err(err)) = response_stream.next().await {
+                yield Ok(err.into_openai_event());
+            }
+
+            // Send final completion signal
             yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
         };

         let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
-        // Non-streaming case
+        // Non-streaming response path
         let (headers, input_length, Json(generation)) =
             generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
-
-        let current_time = std::time::SystemTime::now()
-            .duration_since(std::time::UNIX_EPOCH)
-            .unwrap_or_default()
-            .as_secs();

         let (tool_calls, output) = if using_tools {
+            // Parse generated JSON text
             let gen_text_value: Value =
                 serde_json::from_str(&generation.generated_text).map_err(|e| {
                     InferError::ToolError(format!(
@@ -1468,6 +1571,8 @@ pub(crate) async fn chat_completions(
                         e, generation.generated_text
                     ))
                 })?;
+
+            // Extract function details
             let function = gen_text_value.get("function").ok_or(InferError::ToolError(
                 "No function found in generated text".to_string(),
             ))?;
@@ -1480,24 +1585,28 @@ pub(crate) async fn chat_completions(
                 ))?
                 .to_string();

+            // Prepare arguments (clone and remove _name)
             let mut arguments = function.clone();
             if let Value::Object(ref mut props) = arguments {
                 props.remove("_name");
             }
+
+            // Process based on tool name
             match name.as_str() {
                 "no_tool" => {
-                    // parse the content message
-                    let content_message = arguments
+                    // Extract content for no-tool case
+                    let content = arguments
                         .get("content")
                         .and_then(Value::as_str)
                         .ok_or(InferError::ToolError(
                             "No `content` found in generated text".to_string(),
                         ))?
                         .to_string();
-                    (None, Some(content_message))
+                    (None, Some(content))
                 }
                 _ => {
-                    let tool_calls = vec![ToolCall {
+                    // Create tool call for normal function case
+                    let tool_call = ToolCall {
                         id: "0".to_string(),
                         r#type: "function".to_string(),
                         function: FunctionDefinition {
@@ -1505,26 +1614,27 @@ pub(crate) async fn chat_completions(
                             name,
                             arguments,
                         },
-                    }];
-                    (Some(tool_calls), None)
+                    };
+                    (Some(vec![tool_call]), None)
                 }
             }
         } else {
+            // Standard text output
             (None, Some(generation.generated_text))
         };
-        // build the complete response object with the full text
+
+        // Build complete response with all details
         let response = CompletionType::ChatCompletion(ChatCompletion::new(
             model_id,
             system_fingerprint,
             output,
-            current_time,
+            get_timestamp(),
             generation.details.unwrap(),
             logprobs,
             tool_calls,
             input_length,
         ));

-        // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(response)).into_response())
     }
 }
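For the non-streaming `no_tool` path exercised by `test_flash_llama_tool_reply_response`, the router now returns plain message content instead of `tool_calls` when the model declines every offered tool. A rough sketch of a request that can take that path, reusing the client from the first sketch; the tool definition here is illustrative only:

```python
resp = client.chat.completions.create(
    model="tgi",
    messages=[{"role": "user", "content": "What's the weather like in Paris?"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_current_weather",  # illustrative tool, not from this diff
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }],
    tool_choice="auto",
)

choice = resp.choices[0]
# Per the router change, exactly one side is populated:
# either a tool call, or plain text content for the no_tool case.
assert (choice.message.tool_calls is None) != (choice.message.content is None)
```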