feat: refactor and simplify chat stream more, bump tests and support stream_options

drbh 2025-02-25 20:55:56 +00:00
parent c4cb54c23e
commit a5ddc9db52
6 changed files with 263 additions and 142 deletions
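For orientation, here is what the new stream_options support looks like from a client. This is a minimal sketch using the OpenAI Python SDK; the endpoint URL, the placeholder API key, and the "tgi" model alias are assumptions for illustration. With include_usage set, the stream ends with one extra chunk whose choices list is empty and whose usage field carries the token counts, matching the updated snapshots below.

from openai import OpenAI

# Assumed local TGI endpoint; the API key is a placeholder and "tgi"
# is the alias TGI maps to whatever model it is serving.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

stream = client.chat.completions.create(
    model="tgi",
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
    stream_options={"include_usage": True},
)

for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")
    elif chunk.usage:
        # Final usage-only chunk: empty choices, token counts attached.
        print("\n", chunk.usage)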

View File

@@ -269,6 +269,8 @@ class ResponseComparator(JSONSnapshotExtension):
     def eq_chat_complete_chunk(
         response: ChatCompletionChunk, other: ChatCompletionChunk
     ) -> bool:
+        if len(response.choices) == 0:
+            return len(other.choices) == 0
         return response.choices[0].delta.content == other.choices[0].delta.content

     def eq_response(response: Response, other: Response) -> bool:
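The guard added above matters because the new usage-only chunk has no choices, so the old comparator would index into an empty list. A self-contained mimic using plain dicts (the shapes here are illustrative, not the harness types):

# Mimic of eq_chat_complete_chunk with dicts instead of
# ChatCompletionChunk objects, showing why the empty-choices
# guard is needed once a usage-only chunk can appear.
def eq_chunk(a: dict, b: dict) -> bool:
    if len(a["choices"]) == 0:
        return len(b["choices"]) == 0
    return a["choices"][0]["delta"]["content"] == b["choices"][0]["delta"]["content"]

usage_only = {"choices": [], "usage": {"completion_tokens": 10}}
content = {"choices": [{"delta": {"content": "Hi"}}], "usage": None}

assert eq_chunk(usage_only, usage_only)   # both empty -> equal
assert not eq_chunk(usage_only, content)  # empty vs content -> not equal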

View File

@@ -12,11 +12,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516693,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -32,11 +32,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516693,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -52,11 +52,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516693,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -72,11 +72,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -92,11 +92,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -112,11 +112,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -132,11 +132,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656044,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -152,11 +152,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656044,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -172,11 +172,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656044,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -192,11 +192,20 @@
         "logprobs": null
       }
     ],
-    "created": 1726656044,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
+    "usage": null
+  },
+  {
+    "choices": [],
+    "created": 1740516694,
+    "id": "",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
+    "object": "chat.completion.chunk",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": {
       "completion_tokens": 10,
       "prompt_tokens": 40,

View File

@@ -5,7 +5,7 @@
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from this information, and I recommend checking the forecast on a reliable weather website for the most up-to-date information.",
+        "content": "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from the provided information. For up-to-date information, I suggest checking a reliable weather website or app for the latest conditions and forecast.",
         "name": null,
         "role": "assistant",
         "tool_calls": null
@@ -13,14 +13,14 @@
       "usage": null
     }
   ],
-  "created": 1739932427,
+  "created": 1740516945,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion",
   "system_fingerprint": "3.1.1-dev0-native",
   "usage": {
-    "completion_tokens": 79,
-    "prompt_tokens": 103,
-    "total_tokens": 182
+    "completion_tokens": 83,
+    "prompt_tokens": 109,
+    "total_tokens": 192
   }
 }

View File

@@ -91,7 +91,7 @@ async def test_flash_llama_completion_stream_usage(
             index = c["choices"][0]["index"]
             assert index == 0
             string += c["choices"][0]["delta"]["content"]
+        elif len(c["choices"]) == 0:
             has_usage = c["usage"] is not None
             assert not had_usage
             if has_usage:
@@ -142,7 +142,7 @@ async def test_flash_llama_completion_stream_usage(
             index = c["choices"][0]["index"]
             assert index == 0
             string += c["choices"][0]["delta"]["content"]
+        elif len(c["choices"]) == 0:
             has_usage = c["usage"] is not None
             assert not had_usage
             if has_usage:
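Condensed, the assertion pattern above works like this; the chunk list below is fabricated for illustration, following the snapshot shapes earlier in this commit:

# Content chunks carry exactly one choice; the single usage chunk
# carries none. Mirrors the updated test logic above.
chunks = [
    {"choices": [{"index": 0, "delta": {"content": "Hel"}}], "usage": None},
    {"choices": [{"index": 0, "delta": {"content": "lo"}}], "usage": None},
    {"choices": [], "usage": {"completion_tokens": 2, "prompt_tokens": 5, "total_tokens": 7}},
]

string, usage_chunks = "", 0
for c in chunks:
    if len(c["choices"]) == 1:
        string += c["choices"][0]["delta"]["content"]
    elif len(c["choices"]) == 0:
        assert c["usage"] is not None
        usage_chunks += 1
assert usage_chunks == 1 and string == "Hello"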

View File

@@ -497,7 +497,7 @@ async def test_flash_llama_tool_reply_response(
     assert responses.choices[0].message.tool_calls is None
     assert (
         responses.choices[0].message.content
-        == "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from this information, and I recommend checking the forecast on a reliable weather website for the most up-to-date information."
+        == "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from the provided information. For up-to-date information, I suggest checking a reliable weather website or app for the latest conditions and forecast."
     )
     assert responses == response_snapshot

View File

@@ -1152,10 +1152,10 @@ fn complete_json(partial: &str) -> (String, bool) {
 }

 // Generic function that parses any partial structure into a Map
-fn parse_generic_structure(partial: &str) -> Result<Map<String, Value>, String> {
-    let (completed, _) = complete_json(partial);
+fn parse_generic_structure(partial: &str) -> Result<(Map<String, Value>, bool), String> {
+    let (completed, quote_open) = complete_json(partial);
     match serde_json::from_str::<Value>(&completed) {
-        Ok(Value::Object(obj)) => Ok(obj),
+        Ok(Value::Object(obj)) => Ok((obj, quote_open)),
         _ => Err("Failed to parse as object".to_string()),
     }
 }
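complete_json itself is not shown in this diff. As a rough sketch of what the pair appears to do, inferred from the signatures above rather than ported from the actual Rust: close any dangling string and braces so a partial JSON prefix parses, and report whether a string literal was still open.

import json

# Assumed behavior of complete_json/parse_generic_structure;
# simplified (no array handling, no partial keys).
def complete_json(partial: str) -> tuple[str, bool]:
    depth, in_string, escaped = 0, False, False
    for ch in partial:
        if escaped:
            escaped = False
        elif ch == "\\" and in_string:
            escaped = True
        elif ch == '"':
            in_string = not in_string
        elif not in_string and ch == "{":
            depth += 1
        elif not in_string and ch == "}":
            depth -= 1
    completed = partial + ('"' if in_string else "") + "}" * max(depth, 0)
    return completed, in_string

def parse_generic_structure(partial: str) -> tuple[dict, bool]:
    completed, quote_open = complete_json(partial)
    obj = json.loads(completed)
    if not isinstance(obj, dict):
        raise ValueError("Failed to parse as object")
    return obj, quote_open

obj, quote_open = parse_generic_structure('{"content": "Hel')
assert obj == {"content": "Hel"} and quote_open is True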
@@ -1335,24 +1335,32 @@
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
     metrics::counter!("tgi_request_count").increment(1);

-    let ChatRequest {
-        model,
-        stream,
-        logprobs,
-        // TODO: add back and maybe consolidate the other PR
-        // stream_options,
-        ..
-    } = chat.clone();
-    let (generate_request, using_tools) = chat.try_into_generate(&infer)?;
-    let logprobs = logprobs.unwrap_or_default();
-
-    // extract model id from request if specified
-    let model_id = match model.as_deref() {
-        Some("tgi") | None => info.model_id.clone(),
-        Some(m_id) => m_id.to_string(),
-    };
+    // Extract needed fields
+    let model = chat.model.clone();
+    let logprobs = chat.logprobs.unwrap_or_default();
+    let stream = chat.stream;
+    let stream_options = chat.stream_options.clone();
+
+    // Process request (this consumes chat)
+    let (generate_request, using_tools) = chat.try_into_generate(&infer)?;
+
+    // Determine model ID
+    let model_id = model
+        .as_deref()
+        .filter(|&m| m != "tgi")
+        .unwrap_or(&info.model_id)
+        .to_string();
+
     let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
+
+    // Helper function to get current timestamp
+    let get_timestamp = || {
+        std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .as_secs()
+    };
+
     if stream {
         let (headers, response_stream) =
             generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
@@ -1364,103 +1372,198 @@
             let mut no_tool_chosen = false;
             let mut first_quote_removed = false;
-            while let Some(result) = response_stream.next().await {
-                match result {
-                    Ok(stream_token) => {
-                        let token_text = stream_token.token.text.clone();
-                        json_buffer.push_str(&token_text);
-                        if !name_found {
-                            // since we know tools is attempting to follow a grammar we can attempt to
-                            // partially parse the json_buffer to see if we can extract the function name
-                            if let Ok(function) = parse_partial_json(&json_buffer) {
-                                let name = function.get("_name").and_then(|n| n.as_str()).unwrap_or("no_tool");
-                                name_found = true;
-                                if name == "no_tool" {
-                                    no_tool_chosen = true;
-                                    json_buffer.clear();
-                                    json_buffer.push('{');
-                                } else {
-                                    let tool_name_event = create_event(&token_text, &model_id, &system_fingerprint, Some(name), false, None);
-                                    yield Ok::<Event, Infallible>(tool_name_event);
-                                    let tool_open_arguments_event = create_event("{", &model_id, &system_fingerprint, None, true, None);
-                                    yield Ok::<Event, Infallible>(tool_open_arguments_event);
-                                    // clear the buffer as we know that the buffer is only the function
-                                    // ie: ` {"function": {"_name": "get_current_weather",` -> `{"`
-                                    // we need to keep the `{` to open the arguments and allow the parser to continue
-                                    json_buffer.clear();
-                                    json_buffer.push('{');
-                                }
-                            }
-                        } else {
-                            // Process JSON buffer and handle token text
-                            let last_is_brace = json_buffer.ends_with('}');
-                            let edited_buffer = if last_is_brace { &json_buffer[..json_buffer.len() - 1] } else { &json_buffer };
-                            let mut token_text = stream_token.token.text.clone();
-                            let is_json_complete = serde_json::from_str::<Value>(edited_buffer).is_ok();
-
-                            // Handle tool usage cases
-                            if using_tools {
-                                if no_tool_chosen {
-                                    // Tool without selection ("content" flow)
-                                    if let Ok(function) = parse_generic_structure(edited_buffer) {
-                                        if function.get("content").and_then(|c| c.as_str()).is_some() {
-                                            // Handle quotation marks
-                                            if !first_quote_removed {
-                                                first_quote_removed = true;
-                                                if token_text == "\"" || token_text == " \"" { continue; }
-                                                token_text = token_text.replace("\"", "");
-                                            } else if token_text.ends_with('"') {
-                                                token_text = token_text[..token_text.len() - 1].to_string();
-                                            }
-                                            if is_json_complete { break; }
-                                            yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, false, None));
-                                            continue;
-                                        }
-                                    }
-                                    continue;
-                                } else {
-                                    // Tool with selection
-                                    if is_json_complete {
-                                        // Final token with possible brace removal
-                                        if last_is_brace { token_text = token_text[..token_text.len() - 1].to_string(); }
-                                        yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, true, None));
-                                        break;
-                                    }
-                                    yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, true, None));
-                                    continue;
-                                }
-                            } else {
-                                // Default case: standard chat completion
-                                if let Some(details) = stream_token.details.as_ref() {
-                                    // Handle final token and only send text if ended because of length
-                                    let text = if details.finish_reason == FinishReason::Length { &token_text } else { "" };
-                                    yield Ok::<Event, Infallible>(create_event(text, &model_id, &system_fingerprint, None, false, Some(details.finish_reason.format(true))));
-                                    break;
-                                }
-                                yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, false, None));
-                            }
-                        }
-                    }
-                    Err(err) => yield Ok(err.into_openai_event()),
-                }
-            }
+            // Process stream tokens
+            while let Some(Ok(stream_token)) = response_stream.next().await {
+                let token_text = stream_token.token.text.clone();
+                let mut events = Vec::new();
+                let mut should_break = false;
+
+                // Get usage information
+                let usage = stream_token.details.as_ref().map(|d| Usage {
+                    completion_tokens: d.generated_tokens,
+                    prompt_tokens: d.input_length,
+                    total_tokens: d.input_length + d.generated_tokens,
+                });
+
+                json_buffer.push_str(&token_text);
+
+                // Phase 1: Function name discovery
+                if !name_found {
+                    if let Ok(function) = parse_partial_json(&json_buffer) {
+                        name_found = true;
+
+                        let name = function
+                            .get("_name")
+                            .and_then(|n| n.as_str())
+                            .unwrap_or_default();
+
+                        if name == "no_tool" {
+                            no_tool_chosen = true;
+                        } else {
+                            events.push(create_event(
+                                &token_text,
+                                &model_id,
+                                &system_fingerprint,
+                                Some(name),
+                                false,
+                                None,
+                            ));
+                            events.push(create_event(
+                                "{",
+                                &model_id,
+                                &system_fingerprint,
+                                None,
+                                true,
+                                None,
+                            ));
+                        }
+
+                        // Reset buffer for arguments
+                        json_buffer.clear();
+                        json_buffer.push('{');
+                    }
+
+                    for event in events {
+                        yield Ok::<Event, Infallible>(event);
+                    }
+                    continue;
+                }
+
+                // Phase 2: Content processing
+                let is_complete_json = json_buffer.ends_with('}')
+                    && serde_json::from_str::<Value>(&json_buffer[..json_buffer.len() - 1]).is_ok();
+
+                let mut edited_token = token_text;
+
+                // Handle different flows based on context
+                if using_tools {
+                    if no_tool_chosen && !is_complete_json {
+                        // Content-only flow
+                        if let Ok((function, quote_open)) = parse_generic_structure(&json_buffer) {
+                            if let Some(_content) = function.get("content").and_then(|c| c.as_str()) {
+                                let cleaned_token = if !first_quote_removed {
+                                    // trim start until the first quote
+                                    first_quote_removed = true;
+                                    edited_token
+                                        .trim_start()
+                                        .strip_prefix('"')
+                                        .unwrap_or(&edited_token)
+                                        .to_string()
+                                } else if !quote_open {
+                                    should_break = true;
+                                    // trim end until the last quote
+                                    edited_token
+                                        .trim_end()
+                                        .strip_suffix('"')
+                                        .unwrap_or(&edited_token)
+                                        .to_string()
+                                } else {
+                                    edited_token.to_string()
+                                };
+
+                                if !cleaned_token.is_empty() {
+                                    events.push(create_event(
+                                        &cleaned_token,
+                                        &model_id,
+                                        &system_fingerprint,
+                                        None,
+                                        false,
+                                        None,
+                                    ));
+                                }
+                            }
+                        }
+                    } else {
+                        // Tool with arguments flow
+                        if is_complete_json {
+                            edited_token.truncate(edited_token.len() - 1);
+                            should_break = true;
+                        }
+                        events.push(create_event(
+                            &edited_token,
+                            &model_id,
+                            &system_fingerprint,
+                            None,
+                            true,
+                            None,
+                        ));
+                    }
+                } else {
+                    // Standard chat completion flow
+                    if let Some(details) = stream_token.details.as_ref() {
+                        let finish_reason = details.finish_reason.format(true);
+                        let text = if details.finish_reason == FinishReason::Length {
+                            &edited_token
+                        } else {
+                            ""
+                        };
+                        events.push(create_event(
+                            text,
+                            &model_id,
+                            &system_fingerprint,
+                            None,
+                            false,
+                            Some(finish_reason),
+                        ));
+                        should_break = true;
+                    } else {
+                        events.push(create_event(
+                            &edited_token,
+                            &model_id,
+                            &system_fingerprint,
+                            None,
+                            false,
+                            None,
+                        ));
+                    }
+                }
+
+                // Emit all collected events
+                for event in events {
+                    yield Ok::<Event, Infallible>(event);
+                }
+
+                // Emit usage data when requested
+                if let (Some(usage_data), true) = (
+                    usage,
+                    stream_options.as_ref().is_some_and(|o| o.include_usage)
+                ) {
+                    let current_time = get_timestamp();
+                    let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk {
+                        id: String::new(),
+                        created: current_time,
+                        model: model_id.clone(),
+                        system_fingerprint: system_fingerprint.clone(),
+                        choices: vec![],
+                        usage: Some(usage_data),
+                    });
+
+                    yield Ok(Event::default()
+                        .json_data(chat_complete)
+                        .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into()));
+                }
+
+                if should_break {
+                    break;
+                }
+            }
+
+            // Handle any errors in the stream
+            if let Some(Err(err)) = response_stream.next().await {
+                yield Ok(err.into_openai_event());
+            }
+
+            // Send final completion signal
             yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
         };
         let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
-        // Non-streaming case
+        // Non-streaming response path
         let (headers, input_length, Json(generation)) =
             generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
-        let current_time = std::time::SystemTime::now()
-            .duration_since(std::time::UNIX_EPOCH)
-            .unwrap_or_default()
-            .as_secs();
+
         let (tool_calls, output) = if using_tools {
+            // Parse generated JSON text
             let gen_text_value: Value =
                 serde_json::from_str(&generation.generated_text).map_err(|e| {
                     InferError::ToolError(format!(
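From the client side, the tool-call streaming path above should surface as incremental tool_calls deltas: first the function name, then argument fragments. A hedged end-to-end sketch with the OpenAI Python SDK; the server URL, the "tgi" model alias, and the tool schema are assumptions, and the exact delta shapes depend on the running server.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")
tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}]

stream = client.chat.completions.create(
    model="tgi",
    messages=[{"role": "user", "content": "Weather in Paris?"}],
    tools=tools,
    stream=True,
)

name, args = None, ""
for chunk in stream:
    if not chunk.choices:
        continue  # usage-only chunk
    delta = chunk.choices[0].delta
    if delta.tool_calls:
        call = delta.tool_calls[0]
        if call.function.name:
            name = call.function.name
        if call.function.arguments:
            args += call.function.arguments
print(name, args)  # e.g. get_current_weather {"location": "Paris"}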
@@ -1468,6 +1571,8 @@ pub(crate) async fn chat_completions(
                         e, generation.generated_text
                     ))
                 })?;
+
+            // Extract function details
             let function = gen_text_value.get("function").ok_or(InferError::ToolError(
                 "No function found in generated text".to_string(),
             ))?;
@@ -1480,24 +1585,28 @@
                 ))?
                 .to_string();

+            // Prepare arguments (clone and remove _name)
             let mut arguments = function.clone();
             if let Value::Object(ref mut props) = arguments {
                 props.remove("_name");
             }

+            // Process based on tool name
             match name.as_str() {
                 "no_tool" => {
-                    // parse the content message
-                    let content_message = arguments
+                    // Extract content for no-tool case
+                    let content = arguments
                         .get("content")
                         .and_then(Value::as_str)
                         .ok_or(InferError::ToolError(
                             "No `content` found in generated text".to_string(),
                         ))?
                         .to_string();
-                    (None, Some(content_message))
+                    (None, Some(content))
                 }
                 _ => {
-                    let tool_calls = vec![ToolCall {
+                    // Create tool call for normal function case
+                    let tool_call = ToolCall {
                         id: "0".to_string(),
                         r#type: "function".to_string(),
                         function: FunctionDefinition {
@@ -1505,26 +1614,27 @@
                             name,
                             arguments,
                         },
-                    }];
-                    (Some(tool_calls), None)
+                    };
+                    (Some(vec![tool_call]), None)
                 }
             }
         } else {
+            // Standard text output
             (None, Some(generation.generated_text))
         };

-        // build the complete response object with the full text
+        // Build complete response with all details
         let response = CompletionType::ChatCompletion(ChatCompletion::new(
             model_id,
             system_fingerprint,
             output,
-            current_time,
+            get_timestamp(),
             generation.details.unwrap(),
             logprobs,
             tool_calls,
             input_length,
         ));
-        // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(response)).into_response())
     }
 }
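For the non-streaming path, the parsing logic above can be illustrated with a small worked example; the generated_text value is fabricated to match the grammar-constrained shape the router expects, with the chosen tool under "_name" inside a "function" wrapper.

import json

generated_text = '{"function": {"_name": "get_current_weather", "location": "Paris"}}'
function = json.loads(generated_text)["function"]
name = function.pop("_name")  # remove _name, keep the remaining arguments

if name == "no_tool":
    # No tool selected: the model's reply lives under "content".
    tool_calls, output = None, function["content"]
else:
    # Tool selected: wrap the remaining fields as the call arguments.
    tool_calls, output = (
        [{"id": "0", "type": "function",
          "function": {"name": name, "arguments": function}}],
        None,
    )

print(tool_calls, output)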