feat: refactor and simplify chat stream more, bump tests and support stream_options

drbh 2025-02-25 20:55:56 +00:00
parent c4cb54c23e
commit a5ddc9db52
6 changed files with 263 additions and 142 deletions

View File

@@ -269,6 +269,8 @@ class ResponseComparator(JSONSnapshotExtension):
         def eq_chat_complete_chunk(
             response: ChatCompletionChunk, other: ChatCompletionChunk
         ) -> bool:
+            if len(response.choices) == 0:
+                return len(other.choices) == 0
             return response.choices[0].delta.content == other.choices[0].delta.content

         def eq_response(response: Response, other: Response) -> bool:
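
With stream_options enabled, a stream can now end with a usage-only chunk whose choices list is empty, so the old comparator would raise IndexError on choices[0]. A minimal sketch of the guard's effect, with simplified chunk objects standing in for the openai client types (the real comparator compares delta.content):

class Chunk:
    def __init__(self, choices):
        self.choices = choices

def eq_chunk(a: Chunk, b: Chunk) -> bool:
    # usage-only final chunks have no choices; they only match each other
    if len(a.choices) == 0:
        return len(b.choices) == 0
    return a.choices[0] == b.choices[0]

assert eq_chunk(Chunk([]), Chunk([]))          # two usage-only chunks match
assert not eq_chunk(Chunk([]), Chunk(["hi"]))  # usage-only vs content chunk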

View File

@@ -12,11 +12,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516693,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -32,11 +32,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516693,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -52,11 +52,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516693,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -72,11 +72,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -92,11 +92,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -112,11 +112,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656043,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -132,11 +132,11 @@
         "logprobs": null
      }
     ],
-    "created": 1726656044,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -152,11 +152,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656044,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -172,11 +172,11 @@
         "logprobs": null
       }
     ],
-    "created": 1726656044,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
@@ -192,11 +192,20 @@
         "logprobs": null
       }
     ],
-    "created": 1726656044,
+    "created": 1740516694,
     "id": "",
-    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
-    "system_fingerprint": "2.2.1-dev0-native",
+    "system_fingerprint": "3.1.1-dev0-native",
     "usage": null
   },
   {
+    "choices": [],
+    "created": 1740516694,
+    "id": "",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
+    "object": "chat.completion.chunk",
+    "system_fingerprint": "3.1.1-dev0-native",
+    "usage": {
+      "completion_tokens": 10,
+      "prompt_tokens": 40,

View File

@@ -5,7 +5,7 @@
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from this information, and I recommend checking the forecast on a reliable weather website for the most up-to-date information.",
+        "content": "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from the provided information. For up-to-date information, I suggest checking a reliable weather website or app for the latest conditions and forecast.",
         "name": null,
         "role": "assistant",
         "tool_calls": null
@@ -13,14 +13,14 @@
       "usage": null
     }
   ],
-  "created": 1739932427,
+  "created": 1740516945,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion",
   "system_fingerprint": "3.1.1-dev0-native",
   "usage": {
-    "completion_tokens": 79,
-    "prompt_tokens": 103,
-    "total_tokens": 182
+    "completion_tokens": 83,
+    "prompt_tokens": 109,
+    "total_tokens": 192
   }
 }

View File

@@ -91,7 +91,7 @@ async def test_flash_llama_completion_stream_usage(
             index = c["choices"][0]["index"]
             assert index == 0
             string += c["choices"][0]["delta"]["content"]
-
+        elif len(c["choices"]) == 0:
             has_usage = c["usage"] is not None
             assert not had_usage
             if has_usage:
@@ -142,7 +142,7 @@ async def test_flash_llama_completion_stream_usage(
             index = c["choices"][0]["index"]
             assert index == 0
             string += c["choices"][0]["delta"]["content"]
-
+        elif len(c["choices"]) == 0:
             has_usage = c["usage"] is not None
             assert not had_usage
             if has_usage:
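
Both loops now encode the chunk contract: content arrives only on chunks with exactly one choice, and usage arrives exactly once, on the chunk whose choices list is empty. A sketch of the two shapes being distinguished (field subset only; the counts echo the snapshot above, which implies total_tokens = 50):

content_chunk = {"choices": [{"index": 0, "delta": {"content": "Hello"}}], "usage": None}
usage_chunk = {
    "choices": [],
    "usage": {"prompt_tokens": 40, "completion_tokens": 10, "total_tokens": 50},  # 40 + 10
}

for c in (content_chunk, usage_chunk):
    if len(c["choices"]) == 1:
        print(c["choices"][0]["delta"]["content"])
    elif len(c["choices"]) == 0:
        assert c["usage"] is not None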

View File

@@ -497,7 +497,7 @@ async def test_flash_llama_tool_reply_response(
     assert responses.choices[0].message.tool_calls is None
     assert (
         responses.choices[0].message.content
-        == "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from this information, and I recommend checking the forecast on a reliable weather website for the most up-to-date information."
+        == "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from the provided information. For up-to-date information, I suggest checking a reliable weather website or app for the latest conditions and forecast."
     )

     assert responses == response_snapshot

View File

@@ -1152,10 +1152,10 @@ fn complete_json(partial: &str) -> (String, bool) {
 }

 // Generic function that parses any partial structure into a Map
-fn parse_generic_structure(partial: &str) -> Result<Map<String, Value>, String> {
-    let (completed, _) = complete_json(partial);
+fn parse_generic_structure(partial: &str) -> Result<(Map<String, Value>, bool), String> {
+    let (completed, quote_open) = complete_json(partial);
     match serde_json::from_str::<Value>(&completed) {
-        Ok(Value::Object(obj)) => Ok(obj),
+        Ok(Value::Object(obj)) => Ok((obj, quote_open)),
         _ => Err("Failed to parse as object".to_string()),
     }
 }
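
complete_json's body is not shown in this diff, but its (String, bool) return and the new plumbing indicate that it closes any dangling braces and quotes so a prefix of streamed JSON parses, and reports whether a string literal was still open; parse_generic_structure now forwards that flag so the stream handler can tell when the "content" value has finished. A rough Python model of that contract, a sketch only (objects and strings, no arrays), not the actual implementation:

import json

def complete_json(partial: str):
    quote_open = False
    depth = 0
    escaped = False
    for ch in partial:
        if escaped:
            escaped = False
        elif ch == "\\":
            escaped = quote_open  # escape sequences only matter inside strings
        elif ch == '"':
            quote_open = not quote_open
        elif not quote_open:
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
    completed = partial + ('"' if quote_open else "") + "}" * max(depth, 0)
    return completed, quote_open

completed, quote_open = complete_json('{"function": {"_name": "get_current')
assert json.loads(completed) == {"function": {"_name": "get_current"}}
assert quote_open  # the string was still open mid-stream
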
@@ -1335,24 +1335,32 @@ pub(crate) async fn chat_completions(
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
     metrics::counter!("tgi_request_count").increment(1);

-    let ChatRequest {
-        model,
-        stream,
-        logprobs,
-        // TODO: add back and maybe consolidate the other PR
-        // stream_options,
-        ..
-    } = chat.clone();
-    let (generate_request, using_tools) = chat.try_into_generate(&infer)?;
-    let logprobs = logprobs.unwrap_or_default();
-
-    // extract model id from request if specified
-    let model_id = match model.as_deref() {
-        Some("tgi") | None => info.model_id.clone(),
-        Some(m_id) => m_id.to_string(),
-    };
+    // Extract needed fields
+    let model = chat.model.clone();
+    let logprobs = chat.logprobs.unwrap_or_default();
+    let stream = chat.stream;
+    let stream_options = chat.stream_options.clone();
+
+    // Process request (this consumes chat)
+    let (generate_request, using_tools) = chat.try_into_generate(&infer)?;
+
+    // Determine model ID
+    let model_id = model
+        .as_deref()
+        .filter(|&m| m != "tgi")
+        .unwrap_or(&info.model_id)
+        .to_string();

     let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));

+    // Helper function to get current timestamp
+    let get_timestamp = || {
+        std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .as_secs()
+    };
+
     if stream {
         let (headers, response_stream) =
             generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
@ -1364,103 +1372,198 @@ pub(crate) async fn chat_completions(
let mut no_tool_chosen = false;
let mut first_quote_removed = false;
while let Some(result) = response_stream.next().await {
match result {
Ok(stream_token) => {
let token_text = stream_token.token.text.clone();
json_buffer.push_str(&token_text);
if !name_found {
// since we know tools is attempting to follow a grammar we can attempt to
// partially parse the json_buffer to see if we can extract the function name
if let Ok(function) = parse_partial_json(&json_buffer) {
let name = function.get("_name").and_then(|n| n.as_str()).unwrap_or("no_tool");
name_found = true;
if name == "no_tool" {
no_tool_chosen = true;
json_buffer.clear();
json_buffer.push('{');
} else {
let tool_name_event = create_event(&token_text, &model_id, &system_fingerprint, Some(name), false, None);
yield Ok::<Event, Infallible>(tool_name_event);
let tool_open_arguments_event = create_event("{", &model_id, &system_fingerprint, None, true, None);
yield Ok::<Event, Infallible>(tool_open_arguments_event);
// clear the buffer as we know that the buffer is only the function
// ie: ` {"function": {"_name": "get_current_weather",` -> `{"`
// we need to keep the `{` to open the arguments and allow the parser to continue
json_buffer.clear();
json_buffer.push('{');
}
}
// Process stream tokens
while let Some(Ok(stream_token)) = response_stream.next().await {
let token_text = stream_token.token.text.clone();
let mut events = Vec::new();
let mut should_break = false;
// Get usage information
let usage = stream_token.details.as_ref().map(|d| Usage {
completion_tokens: d.generated_tokens,
prompt_tokens: d.input_length,
total_tokens: d.input_length + d.generated_tokens,
});
json_buffer.push_str(&token_text);
// Phase 1: Function name discovery
if !name_found {
if let Ok(function) = parse_partial_json(&json_buffer) {
name_found = true;
let name = function
.get("_name")
.and_then(|n| n.as_str())
.unwrap_or_default();
if name == "no_tool" {
no_tool_chosen = true;
} else {
// Process JSON buffer and handle token text
let last_is_brace = json_buffer.ends_with('}');
let edited_buffer = if last_is_brace { &json_buffer[..json_buffer.len() - 1] } else { &json_buffer };
let mut token_text = stream_token.token.text.clone();
let is_json_complete = serde_json::from_str::<Value>(edited_buffer).is_ok();
events.push(create_event(
&token_text,
&model_id,
&system_fingerprint,
Some(name),
false,
None,
));
events.push(create_event(
"{",
&model_id,
&system_fingerprint,
None,
true,
None,
));
}
// Handle tool usage cases
if using_tools {
if no_tool_chosen {
// Tool without selection ("content" flow)
if let Ok(function) = parse_generic_structure(edited_buffer) {
if function.get("content").and_then(|c| c.as_str()).is_some() {
// Handle quotation marks
if !first_quote_removed {
first_quote_removed = true;
if token_text == "\"" || token_text == " \"" { continue; }
token_text = token_text.replace("\"", "");
} else if token_text.ends_with('"') {
token_text = token_text[..token_text.len() - 1].to_string();
}
// Reset buffer for arguments
json_buffer.clear();
json_buffer.push('{');
}
if is_json_complete { break; }
yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, false, None));
continue;
}
}
continue;
for event in events {
yield Ok::<Event, Infallible>(event);
}
continue;
}
// Phase 2: Content processing
let is_complete_json = json_buffer.ends_with('}')
&& serde_json::from_str::<Value>(&json_buffer[..json_buffer.len() - 1]).is_ok();
let mut edited_token = token_text;
// Handle different flows based on context
if using_tools {
if no_tool_chosen && !is_complete_json {
// Content-only flow
if let Ok((function, quote_open)) = parse_generic_structure(&json_buffer) {
if let Some(_content) = function.get("content").and_then(|c| c.as_str()) {
let cleaned_token = if !first_quote_removed {
// trim start unil the first quote
first_quote_removed = true;
edited_token
.trim_start()
.strip_prefix('"')
.unwrap_or(&edited_token)
.to_string()
} else if !quote_open {
should_break = true;
// trim end until the last quote
edited_token
.trim_end()
.strip_suffix('"')
.unwrap_or(&edited_token)
.to_string()
} else {
// Tool with selection
if is_json_complete {
// Final token with possible brace removal
if last_is_brace { token_text = token_text[..token_text.len() - 1].to_string(); }
yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, true, None));
break;
}
yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, true, None));
continue;
edited_token.to_string()
};
if !cleaned_token.is_empty() {
events.push(create_event(
&cleaned_token,
&model_id,
&system_fingerprint,
None,
false,
None,
));
}
} else {
// Default case: standard chat completion
if let Some(details) = stream_token.details.as_ref() {
// Handle final token and only send text if ended because of length
let text = if details.finish_reason == FinishReason::Length { &token_text } else { "" };
yield Ok::<Event, Infallible>(create_event(text, &model_id, &system_fingerprint, None, false, Some(details.finish_reason.format(true))));
break;
}
yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, false, None));
}
}
} else {
// Tool with arguments flow
if is_complete_json {
edited_token.truncate(edited_token.len() - 1);
should_break = true;
}
events.push(create_event(
&edited_token,
&model_id,
&system_fingerprint,
None,
true,
None,
));
}
Err(err) => yield Ok(err.into_openai_event()),
} else {
// Standard chat completion flow
if let Some(details) = stream_token.details.as_ref() {
let finish_reason = details.finish_reason.format(true);
let text = if details.finish_reason == FinishReason::Length {
&edited_token
} else {
""
};
events.push(create_event(
text,
&model_id,
&system_fingerprint,
None,
false,
Some(finish_reason),
));
should_break = true;
} else {
events.push(create_event(
&edited_token,
&model_id,
&system_fingerprint,
None,
false,
None,
));
}
}
// Emit all collected events
for event in events {
yield Ok::<Event, Infallible>(event);
}
// Emit usage data when requested
if let (Some(usage_data), true) = (
usage,
stream_options.as_ref().is_some_and(|o| o.include_usage)
) {
let current_time = get_timestamp();
let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk {
id: String::new(),
created: current_time,
model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(),
choices: vec![],
usage: Some(usage_data),
});
yield Ok(Event::default()
.json_data(chat_complete)
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into()));
}
if should_break {
break;
}
}
// Handle any errors in the stream
if let Some(Err(err)) = response_stream.next().await {
yield Ok(err.into_openai_event());
}
// Send final completion signal
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
};
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response())
} else {
// Non-streaming case
// Non-streaming response path
let (headers, input_length, Json(generation)) =
generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let (tool_calls, output) = if using_tools {
// Parse generated JSON text
let gen_text_value: Value =
serde_json::from_str(&generation.generated_text).map_err(|e| {
InferError::ToolError(format!(
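
One subtlety in the content-only flow above: the model's reply is streamed from inside a JSON string ({"content": "..."}), so the first emitted token must lose its opening quote and the final token its closing quote, with quote_open from complete_json signalling when the string has closed. A rough Python rendering of that trimming (state names assumed; a sketch of the Rust logic, not a port):

def clean_content_token(token: str, state: dict, quote_open: bool):
    """Returns (cleaned_token, should_break); mirrors the content flow above."""
    if not state["first_quote_removed"]:
        state["first_quote_removed"] = True
        t = token.lstrip()
        return (t[1:] if t.startswith('"') else token), False
    if not quote_open:  # the JSON string has closed: last content token
        t = token.rstrip()
        return (t[:-1] if t.endswith('"') else token), True
    return token, False
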
@@ -1468,6 +1571,8 @@ pub(crate) async fn chat_completions(
                     e, generation.generated_text
                 ))
             })?;
+
+            // Extract function details
             let function = gen_text_value.get("function").ok_or(InferError::ToolError(
                 "No function found in generated text".to_string(),
             ))?;
@@ -1480,24 +1585,28 @@ pub(crate) async fn chat_completions(
                 ))?
                 .to_string();

+            // Prepare arguments (clone and remove _name)
             let mut arguments = function.clone();
             if let Value::Object(ref mut props) = arguments {
                 props.remove("_name");
             }

+            // Process based on tool name
             match name.as_str() {
                 "no_tool" => {
-                    // parse the content message
-                    let content_message = arguments
+                    // Extract content for no-tool case
+                    let content = arguments
                         .get("content")
                         .and_then(Value::as_str)
                         .ok_or(InferError::ToolError(
                             "No `content` found in generated text".to_string(),
                         ))?
                         .to_string();
-                    (None, Some(content_message))
+                    (None, Some(content))
                 }
                 _ => {
-                    let tool_calls = vec![ToolCall {
+                    // Create tool call for normal function case
+                    let tool_call = ToolCall {
                         id: "0".to_string(),
                         r#type: "function".to_string(),
                         function: FunctionDefinition {
@@ -1505,26 +1614,27 @@ pub(crate) async fn chat_completions(
                             name,
                             arguments,
                         },
-                    }];
-                    (Some(tool_calls), None)
+                    };
+                    (Some(vec![tool_call]), None)
                 }
             }
         } else {
+            // Standard text output
             (None, Some(generation.generated_text))
         };

-        // build the complete response object with the full text
+        // Build complete response with all details
         let response = CompletionType::ChatCompletion(ChatCompletion::new(
             model_id,
             system_fingerprint,
             output,
-            current_time,
+            get_timestamp(),
             generation.details.unwrap(),
             logprobs,
             tool_calls,
             input_length,
         ));
-        // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(response)).into_response())
     }
 }
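
In the non-streaming path, the grammar constrains the model to emit a {"function": {"_name": ..., ...}} object, which the code above unwraps into either a plain content reply (the no_tool case) or a single tool call. A condensed Python model of that post-processing, with shapes assumed from the Rust above:

import json

def postprocess_tool_output(generated_text: str):
    value = json.loads(generated_text)
    function = value["function"]   # the grammar guarantees a "function" object
    name = function.pop("_name")   # drop the internal _name marker
    if name == "no_tool":
        return None, function["content"]  # plain text reply, no tool calls
    tool_call = {
        "id": "0",
        "type": "function",
        "function": {"name": name, "arguments": function},
    }
    return [tool_call], None       # one tool call, no text content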