feat: refactor chat stream to remove state machine and simplify logic

This commit is contained in:
drbh 2025-02-24 21:51:33 +00:00
parent a416ddbdd9
commit 31a536d796
5 changed files with 255 additions and 258 deletions

View File

@@ -2,7 +2,7 @@
   "choices": [
     {
      "delta": {
-        "content": " assistant",
+        "content": "!",
         "role": "assistant",
         "tool_calls": null
       },
@@ -11,7 +11,7 @@
       "logprobs": null
     }
   ],
-  "created": 1739441937,
+  "created": 1740432006,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",

View File

@@ -2,7 +2,7 @@
   "choices": [
     {
       "delta": {
-        "content": " Oracle",
+        "content": ".",
         "role": "assistant",
         "tool_calls": null
       },
@@ -11,7 +11,7 @@
       "logprobs": null
     }
   ],
-  "created": 1739444803,
+  "created": 1740432012,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",

View File

@@ -11,7 +11,7 @@
       "logprobs": null
     }
   ],
-  "created": 1739454835,
+  "created": 1740433572,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",

View File

@@ -281,8 +281,8 @@ async def test_flash_llama_grammar_tools_insufficient_information_stream(
         last_response = response
     assert response.choices[0].delta.tool_calls is None
-    assert count == 5
-    assert content_generated == "I am a helpful assistant"
+    assert count == 6
+    assert content_generated == "I am a helpful assistant!"
     assert last_response == response_snapshot
@@ -318,10 +318,10 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream(
         last_response = response
     assert response.choices[0].delta.tool_calls is None
-    assert count == 77
+    assert count == 78
     assert (
         content_generated
-        == "There was a wise old octopus named Oracle. He lived in a cozy little cave beneath the waves with his best friend, a curious seahorse named Finley. One day, Finley met a playful dolphin named Daisy, and the three became inseparable. They spent their days exploring the ocean, playing hide-and-seek, and learning about the wonders of the sea from Oracle"
+        == "There was a wise old octopus named Oracle. He lived in a cozy little cave beneath the waves with his best friend, a curious seahorse named Finley. One day, Finley met a playful dolphin named Daisy, and the three became inseparable. They spent their days exploring the ocean, playing hide-and-seek, and learning about the wonders of the sea from Oracle."
     )
     assert last_response == response_snapshot
@@ -401,7 +401,6 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_none(
     assert response.choices[0].delta.tool_calls is None
     assert count == 100
-    print(content_generated)
     assert (
         content_generated
         == "Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish, lived three dear friends: Luna the sea turtle, Finley the friendly fish, and Crusty the wise crab.\n\nLuna was the oldest of the three. She had traveled the world, exploring hidden caves and shipwrecks, and collecting sparkling shells and shiny pebbles. Her shell was a beautiful mosaic of blues and greens, and her gentle eyes twinkled with the secrets of the deep"

View File

@@ -47,7 +47,7 @@ use http::header::AUTHORIZATION;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use pyo3::prelude::*;
 use pyo3::types::IntoPyDict;
-use regex::Regex;
+use serde_json::Map;
 use serde_json::Value;
 use std::convert::Infallible;
 use std::fs::File;
@@ -1114,84 +1114,183 @@ pub(crate) async fn completions(
         }
     }
 
-enum StreamState {
-    Buffering,
-    BufferTrailing,
-    Content { skip_close_quote: bool },
-}
+// balance the started json with closing braces and quotes
+fn complete_json(partial: &str) -> (String, bool) {
+    let mut brace_count = 0;
+    let mut quote_open = false;
+    let mut escaped = false;
+    let mut last_char = '\0';
+    for c in partial.chars() {
+        match (escaped, quote_open, c) {
+            (true, _, _) => escaped = false,
+            (false, _, '\\') => escaped = true,
+            (false, _, '"') => quote_open = !quote_open,
+            (false, false, '{') => brace_count += 1,
+            (false, false, '}') => brace_count -= 1,
+            _ => {}
+        }
+        if !c.is_whitespace() {
+            last_char = c;
+        }
+    }
+    let mut completed = partial.to_string();
+    if last_char == ',' {
+        if let Some(pos) = completed.rfind(',') {
+            completed.replace_range(pos..pos + 1, "");
+        }
+    }
+    if quote_open {
+        completed.push('"');
+    }
+    completed.push_str(&"}".repeat(brace_count.max(0)));
+    (completed, quote_open)
+}
 
-/// Convert a StreamResponse into an Event to be sent over SSE
-fn create_event_from_stream_token(
-    stream_token: &StreamResponse,
-    logprobs: bool,
-    stream_options: Option<StreamOptions>,
-    inner_using_tools: bool,
-    system_fingerprint: String,
-    model_id: String,
-    tool_name: Option<String>,
+// Generic function that parses any partial structure into a Map
+fn parse_generic_structure(partial: &str) -> Result<Map<String, Value>, String> {
+    let (completed, _) = complete_json(partial);
+    match serde_json::from_str::<Value>(&completed) {
+        Ok(Value::Object(obj)) => Ok(obj),
+        _ => Err("Failed to parse as object".to_string()),
+    }
+}
+
+// Parse partial JSON into a Map with a function object
+fn parse_partial_json(partial: &str) -> Result<Map<String, Value>, String> {
+    let (completed, was_quote_open) = complete_json(partial);
+    match serde_json::from_str::<Value>(&completed) {
+        Ok(Value::Object(obj)) => {
+            if let Some(Value::Object(function)) = obj.get("function") {
+                let name_is_only_key = function.len() == 1;
+                if was_quote_open && name_is_only_key {
+                    let mut function = function.clone();
+                    if let Some(Value::String(ref mut name)) = function.get_mut("_name") {
+                        name.clear();
+                    }
+                    return Err("Missing _name in function".to_string());
+                }
+                Ok(function.clone())
+            } else {
+                Err("Missing function object".to_string())
+            }
+        }
+        _ => Err("Failed to parse as object".to_string()),
+    }
+}
+
+/// Creates an event based on the token text and event type parameters.
+/// `token_text` - The text to include (extract from StreamResponse.token.text or str)
+/// `model_id` - Model identifier string
+/// `system_fingerprint` - System fingerprint string
+/// `tool_name` - If provided, creates a tool call name event
+/// `is_tool_arg` - If true, creates a tool call argument event
+fn create_event(
+    token_text: &str,
+    model_id: &str,
+    system_fingerprint: &str,
+    tool_name: Option<&str>,
+    is_tool_arg: bool,
+    finish_reason: Option<String>,
 ) -> Event {
-    let event = Event::default();
     let current_time = std::time::SystemTime::now()
         .duration_since(std::time::UNIX_EPOCH)
-        .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+        .unwrap_or_default()
         .as_secs();
 
-    let logprobs = logprobs.then(|| {
-        ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone()))
-    });
-
-    // replace the content with the tool calls if grammar is present
-    let (content, tool_calls) = if inner_using_tools {
-        (None, Some(vec![stream_token.token.text.clone()]))
+    let chat_complete = if let Some(tool_name) = tool_name {
+        // Tool call name event
+        let tool_delta = ChatCompletionDelta::Tool(ToolCallDelta {
+            role: "assistant".to_string(),
+            tool_calls: vec![DeltaToolCall {
+                index: 0,
+                id: String::new(),
+                r#type: "function".to_string(),
+                function: Function {
+                    name: Some(tool_name.to_string()),
+                    arguments: "".to_string(),
+                },
+            }],
+        });
+        CompletionType::ChatCompletionChunk(ChatCompletionChunk {
+            id: String::new(),
+            created: current_time,
+            model: model_id.to_string(),
+            system_fingerprint: system_fingerprint.to_string(),
+            choices: vec![ChatCompletionChoice {
+                index: 0,
+                delta: tool_delta,
+                logprobs: None,
+                finish_reason: None,
+            }],
+            usage: None,
+        })
+    } else if is_tool_arg {
+        // Tool call argument event
+        let tool_delta = ChatCompletionDelta::Tool(ToolCallDelta {
+            role: "assistant".to_string(),
+            tool_calls: vec![DeltaToolCall {
+                index: 0,
+                id: String::new(),
+                r#type: "function".to_string(),
+                function: Function {
+                    name: None,
+                    arguments: token_text.to_string(),
+                },
+            }],
+        });
+        CompletionType::ChatCompletionChunk(ChatCompletionChunk {
+            id: String::new(),
+            created: current_time,
+            model: model_id.to_string(),
+            system_fingerprint: system_fingerprint.to_string(),
+            choices: vec![ChatCompletionChoice {
+                index: 0,
+                delta: tool_delta,
+                logprobs: None,
+                finish_reason: None,
+            }],
+            usage: None,
+        })
     } else {
-        let content = if !stream_token.token.special {
-            Some(stream_token.token.text.clone())
+        // usage, finish_reason
+        if finish_reason.is_some() {
+            CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
+                model_id.to_string(),
+                system_fingerprint.to_string(),
+                Some(token_text.to_string()),
+                None,
+                current_time,
+                None,
+                finish_reason,
+                None,
+                None,
+            ))
         } else {
-            None
-        };
-
-        (content, None)
-    };
-
-    let (usage, finish_reason) = match &stream_token.details {
-        Some(details) => {
-            let usage = if stream_options
-                .as_ref()
-                .map(|s| s.include_usage)
-                .unwrap_or(false)
-            {
-                let completion_tokens = details.generated_tokens;
-                let prompt_tokens = details.input_length;
-                let total_tokens = prompt_tokens + completion_tokens;
-                Some(Usage {
-                    completion_tokens,
-                    prompt_tokens,
-                    total_tokens,
-                })
-            } else {
-                None
-            };
-            (usage, Some(details.finish_reason.format(true)))
+            // Chat completion event
+            CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
+                model_id.to_string(),
+                system_fingerprint.to_string(),
+                Some(token_text.to_string()),
+                None,
+                current_time,
+                None,
+                None,
+                None,
+                None,
+            ))
         }
-        None => (None, None),
     };
 
-    let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
-        model_id.clone(),
-        system_fingerprint.clone(),
-        content,
-        tool_calls,
-        current_time,
-        logprobs,
-        finish_reason,
-        usage,
-        tool_name,
-    ));
-
-    event.json_data(chat_complete).unwrap_or_else(|e| {
-        println!("Failed to serialize ChatCompletionChunk: {:?}", e);
-        Event::default()
-    })
+    Event::default()
+        .json_data(chat_complete)
+        .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
 }
 
 /// Generate tokens
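For a concrete feel for the balancing trick above, here is a self-contained sketch of the same algorithm with a small driver. The `main` function and its sample inputs are illustrative additions, not part of the commit; the loop body mirrors the `complete_json` shown in the diff.

// Standalone sketch of the brace/quote balancing idea from `complete_json`.
fn complete_json(partial: &str) -> (String, bool) {
    let mut brace_count: i32 = 0;
    let mut quote_open = false;
    let mut escaped = false;
    let mut last_char = '\0';
    for c in partial.chars() {
        match (escaped, quote_open, c) {
            (true, _, _) => escaped = false,             // char after a backslash is literal
            (false, _, '\\') => escaped = true,          // start of an escape sequence
            (false, _, '"') => quote_open = !quote_open, // toggle in/out of a string
            (false, false, '{') => brace_count += 1,     // only count braces outside strings
            (false, false, '}') => brace_count -= 1,
            _ => {}
        }
        if !c.is_whitespace() {
            last_char = c;
        }
    }
    let mut completed = partial.to_string();
    if last_char == ',' {
        // drop a trailing comma so the closed object stays valid JSON
        if let Some(pos) = completed.rfind(',') {
            completed.replace_range(pos..pos + 1, "");
        }
    }
    if quote_open {
        completed.push('"'); // close the dangling string first
    }
    completed.push_str(&"}".repeat(brace_count.max(0) as usize));
    (completed, quote_open)
}

fn main() {
    // Mid-stream buffers as a grammar-constrained model emits a tool call.
    for partial in [
        r#"{"function": {"_name": "get_current_weather""#,
        r#"{"function": {"_name": "get_current_weather", "location": "Par"#,
    ] {
        let (completed, quote_was_open) = complete_json(partial);
        println!("{completed}  (open quote: {quote_was_open})");
    }
}

Both buffers print as parseable JSON: the first becomes `{"function": {"_name": "get_current_weather"}}`, and the second closes the dangling string before closing the braces.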
@@ -1239,13 +1338,12 @@ pub(crate) async fn chat_completions(
     let ChatRequest {
         model,
         stream,
-        stream_options,
         logprobs,
+        // TODO: add back and maybe consolidate the other PR
+        // stream_options,
         ..
     } = chat.clone();
-    let (generate_request, using_tools): (GenerateRequest, bool) =
-        chat.try_into_generate(&infer)?;
+    let (generate_request, using_tools) = chat.try_into_generate(&infer)?;
     let logprobs = logprobs.unwrap_or_default();
 
     // extract model id from request if specified
@@ -1254,210 +1352,112 @@ pub(crate) async fn chat_completions(
         Some(m_id) => m_id.to_string(),
     };
     let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
-    // switch on stream
     if stream {
         let (headers, response_stream) =
             generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
-        // regex to match any function name
-        let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
-            Ok(regex) => regex,
-            Err(e) => {
-                return Err((
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(ErrorResponse {
-                        error: format!("Failed to compile regex: {}", e),
-                        error_type: "regex".to_string(),
-                    }),
-                ))
-            }
-        };
 
         let response_stream = async_stream::stream! {
             let mut response_stream = Box::pin(response_stream);
-            let mut buffer = Vec::new();
             let mut json_buffer = String::new();
-            let mut state = if using_tools {
-                StreamState::Buffering
-            } else {
-                StreamState::Content {
-                    skip_close_quote: false,
-                }
-            };
-            let mut response_as_tool = using_tools;
+            let mut name_found = !using_tools;
+            let mut no_tool_chosen = false;
+            let mut first_quote_removed = false;
             while let Some(result) = response_stream.next().await {
-                match result{
+                match result {
                     Ok(stream_token) => {
-                        let token_text = &stream_token.token.text.clone();
-                        match state {
-                            StreamState::Buffering => {
-                                json_buffer.push_str(&token_text.replace(" ", ""));
-                                buffer.push(stream_token);
-                                if let Some(captures) = function_regex.captures(&json_buffer) {
-                                    let function_name = captures[1].to_string();
-                                    if function_name == "no_tool" {
-                                        state = StreamState::BufferTrailing;
-                                        response_as_tool = false;
-                                        buffer.clear();
-                                        json_buffer.clear();
-                                    } else {
-                                        state = StreamState::Content {
-                                            skip_close_quote: false,
-                                        };
-                                        let event = Event::default();
-                                        let current_time = std::time::SystemTime::now()
-                                            .duration_since(std::time::UNIX_EPOCH)
-                                            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
-                                            .as_secs();
-                                        let tool_delta_start = ChatCompletionDelta::Tool(ToolCallDelta {
-                                            role: "assistant".to_string(),
-                                            tool_calls: vec![DeltaToolCall {
-                                                index: 0,
-                                                id: String::new(),
-                                                r#type: "function".to_string(),
-                                                function: Function {
-                                                    name: Some(function_name.clone()),
-                                                    arguments: "".to_string(),
-                                                },
-                                            }],
-                                        });
-                                        let chat_complete =
-                                            CompletionType::ChatCompletionChunk(ChatCompletionChunk{
-                                                id: String::new(),
-                                                created: current_time,
-                                                model: model_id.clone(),
-                                                system_fingerprint: system_fingerprint.clone(),
-                                                choices: vec![ChatCompletionChoice {
-                                                    index: 0,
-                                                    delta: tool_delta_start,
-                                                    logprobs: None,
-                                                    finish_reason: None,
-                                                }],
-                                                usage: None,
-                                            });
-                                        yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
-                                            InferError::StreamSerializationError(e.to_string()).into()
-                                        }));
-                                        buffer.drain(1..); // only keep the first token (opening '{')
-                                        buffer[0].token.text = buffer[0].token.text.chars().take(1).collect();
-                                    }
-                                }
-                            }
-                            // if we skipped sending the buffer we need to avoid sending the following json key and quotes
-                            StreamState::BufferTrailing => {
-                                let infix_text = "\"content\":\"";
-                                json_buffer.push_str(&token_text.replace(" ", ""));
-                                // keep capturing until we find the infix text
-                                match json_buffer.find(infix_text) {
-                                    Some(content_key_index) => {
-                                        json_buffer =
-                                            json_buffer[content_key_index + infix_text.len()..].to_string();
-                                    }
-                                    None => {
-                                        continue;
-                                    }
-                                }
-                                // if there is leftover text after removing the infix text, we need to send it
-                                if !json_buffer.is_empty() {
-                                    let event = Event::default();
-                                    let current_time = std::time::SystemTime::now()
-                                        .duration_since(std::time::UNIX_EPOCH)
-                                        .unwrap_or_else(|_| std::time::Duration::from_secs(0))
-                                        .as_secs();
-                                    let chat_complete =
-                                        CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
-                                            model_id.clone(),
-                                            system_fingerprint.clone(),
-                                            Some(json_buffer.clone()),
-                                            None,
-                                            current_time,
-                                            None,
-                                            None,
-                                            None,
-                                            None,
-                                        ));
-                                    yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
-                                        InferError::StreamSerializationError(e.to_string()).into()
-                                    }));
-                                }
-                                // cleanup the buffers
-                                buffer.clear();
-                                json_buffer.clear();
-                                state = StreamState::Content {
-                                    skip_close_quote: true,
-                                };
-                            }
-                            StreamState::Content { skip_close_quote } => {
-                                if skip_close_quote && token_text.contains('"') {
-                                    break;
-                                }
-                                buffer.push(stream_token);
-                                if buffer.len() > 1 {
-                                    // FIFO send the buffer but left the last two elements (closing '}' and EOS token)
-                                    for stream_token in &buffer[..buffer.len() - 2] {
-                                        let event = create_event_from_stream_token(
-                                            stream_token,
-                                            logprobs,
-                                            stream_options.clone(),
-                                            response_as_tool,
-                                            system_fingerprint.clone(),
-                                            model_id.clone(),
-                                            None,
-                                        );
-                                        yield Ok::<Event, Infallible>(event);
-                                    }
-                                    buffer = buffer.drain(buffer.len() - 2..).collect();
-                                }
-                            }
-                        }
-                    }
-                    Err(err) => yield Ok(err.into_openai_event())
-                }
-            }
-            if response_as_tool {
-                // send the second to last stream token but remove the trailing '}' if it exists
-                let mut closing_stream_token = buffer.remove(0);
-                closing_stream_token.token.text = closing_stream_token.token.text.strip_suffix("}").unwrap_or(&closing_stream_token.token.text).to_string();
-                let event = create_event_from_stream_token(
-                    &closing_stream_token,
-                    logprobs,
-                    stream_options.clone(),
-                    response_as_tool,
-                    system_fingerprint.clone(),
-                    model_id.clone(),
-                    None,
-                );
-                yield Ok::<Event, Infallible>(event);
-            } else {
-                // send each buffer element
-                for stream_token in buffer {
-                    let event = create_event_from_stream_token(
-                        &stream_token,
-                        logprobs,
-                        stream_options.clone(),
-                        response_as_tool,
-                        system_fingerprint.clone(),
-                        model_id.clone(),
-                        None,
-                    );
-                    yield Ok::<Event, Infallible>(event);
-                }
-            }
+                        let token_text = stream_token.token.text.clone();
+                        json_buffer.push_str(&token_text);
+                        if !name_found {
+                            // since we know tools is attempting to follow a grammar we can attempt to
+                            // partially parse the json_buffer to see if we can extract the function name
+                            if let Ok(function) = parse_partial_json(&json_buffer) {
+                                let name = function.get("_name").and_then(|n| n.as_str()).unwrap_or("no_tool");
+                                name_found = true;
+                                if name == "no_tool" {
+                                    no_tool_chosen = true;
+                                    json_buffer.clear();
+                                    json_buffer.push('{');
+                                } else {
+                                    let tool_name_event = create_event(&token_text, &model_id, &system_fingerprint, Some(name), false, None);
+                                    yield Ok::<Event, Infallible>(tool_name_event);
+                                    let tool_open_arguments_event = create_event("{", &model_id, &system_fingerprint, None, true, None);
+                                    yield Ok::<Event, Infallible>(tool_open_arguments_event);
+                                    // clear the buffer as we know that the buffer is only the function
+                                    // ie: ` {"function": {"_name": "get_current_weather",` -> `{"`
+                                    // we need to keep the `{` to open the arguments and allow the parser to continue
+                                    json_buffer.clear();
+                                    json_buffer.push('{');
+                                }
+                            }
+                        } else {
+                            // Process JSON buffer and handle token text
+                            let last_is_brace = json_buffer.ends_with('}');
+                            let edited_buffer = if last_is_brace { &json_buffer[..json_buffer.len() - 1] } else { &json_buffer };
+                            let mut token_text = stream_token.token.text.clone();
+                            let is_json_complete = serde_json::from_str::<Value>(edited_buffer).is_ok();
+
+                            // Handle tool usage cases
+                            if using_tools {
+                                if no_tool_chosen {
+                                    // Tool without selection ("content" flow)
+                                    if let Ok(function) = parse_generic_structure(edited_buffer) {
+                                        if function.get("content").and_then(|c| c.as_str()).is_some() {
+                                            // Handle quotation marks
+                                            if !first_quote_removed {
+                                                first_quote_removed = true;
+                                                if token_text == "\"" || token_text == " \"" { continue; }
+                                                token_text = token_text.replace("\"", "");
+                                            } else if token_text.ends_with('"') {
+                                                token_text = token_text[..token_text.len() - 1].to_string();
+                                            }
+                                            if is_json_complete { break; }
+                                            yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, false, None));
+                                            continue;
+                                        }
+                                    }
+                                    continue;
+                                } else {
+                                    // Tool with selection
+                                    if is_json_complete {
+                                        // Final token with possible brace removal
+                                        if last_is_brace { token_text = token_text[..token_text.len() - 1].to_string(); }
+                                        yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, true, None));
+                                        break;
+                                    }
+                                    yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, true, None));
+                                    continue;
+                                }
+                            } else {
+                                // Default case: standard chat completion
+                                if let Some(details) = stream_token.details.as_ref() {
+                                    // Handle final token and only send text if ended because of length
+                                    let text = if details.finish_reason == FinishReason::Length { &token_text } else { "" };
+                                    yield Ok::<Event, Infallible>(create_event(text, &model_id, &system_fingerprint, None, false, Some(details.finish_reason.format(true))));
+                                    break;
+                                }
+                                yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, false, None));
+                            }
+                        }
+                    }
+                    Err(err) => yield Ok(err.into_openai_event()),
+                }
+            }
             yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
         };
         let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
+        // Non-streaming case
         let (headers, input_length, Json(generation)) =
             generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
         let current_time = std::time::SystemTime::now()
             .duration_since(std::time::UNIX_EPOCH)
-            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .unwrap_or_default()
             .as_secs();
         let (tool_calls, output) = if using_tools {
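To see what this loop looks like from the client side when a tool is selected, here is a hedged sketch: the first delta names the function, a synthetic `{` opens the arguments, and later deltas carry raw argument fragments; once the trailing grammar brace is stripped from the final chunk, the concatenated fragments should parse as a JSON object. The fragment values below are invented for illustration.

use serde_json::Value;

fn main() {
    // Hypothetical argument fragments, in the order the refactored stream
    // would yield them: the synthetic "{" first, then raw token texts, with
    // the outer grammar brace already stripped from the final chunk.
    let fragments = ["{", "\"location\":", " \"Paris\",", " \"format\":", " \"celsius\"", "}"];
    let arguments: String = fragments.concat();
    // The concatenated fragments must form a complete JSON object.
    let parsed: Value = serde_json::from_str(&arguments).expect("arguments should be valid JSON");
    println!("{parsed}");
}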
@@ -1490,11 +1490,9 @@ pub(crate) async fn chat_completions(
             let content_message = arguments
                 .get("content")
                 .and_then(Value::as_str)
-                .ok_or_else(|| {
-                    InferError::ToolError(
-                        "No `content` found in generated text".to_string(),
-                    )
-                })?
+                .ok_or(InferError::ToolError(
+                    "No `content` found in generated text".to_string(),
+                ))?
                 .to_string();
             (None, Some(content_message))