feat: refactor chat stream to remove state machine and simplfy logic

This commit is contained in:
drbh 2025-02-24 21:51:33 +00:00
parent a416ddbdd9
commit 31a536d796
5 changed files with 255 additions and 258 deletions

View File

@ -2,7 +2,7 @@
"choices": [
{
"delta": {
"content": " assistant",
"content": "!",
"role": "assistant",
"tool_calls": null
},
@ -11,7 +11,7 @@
"logprobs": null
}
],
"created": 1739441937,
"created": 1740432006,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",

View File

@ -2,7 +2,7 @@
"choices": [
{
"delta": {
"content": " Oracle",
"content": ".",
"role": "assistant",
"tool_calls": null
},
@ -11,7 +11,7 @@
"logprobs": null
}
],
"created": 1739444803,
"created": 1740432012,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",

View File

@ -11,7 +11,7 @@
"logprobs": null
}
],
"created": 1739454835,
"created": 1740433572,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",

View File

@ -281,8 +281,8 @@ async def test_flash_llama_grammar_tools_insufficient_information_stream(
last_response = response
assert response.choices[0].delta.tool_calls is None
assert count == 5
assert content_generated == "I am a helpful assistant"
assert count == 6
assert content_generated == "I am a helpful assistant!"
assert last_response == response_snapshot
@ -318,10 +318,10 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream(
last_response = response
assert response.choices[0].delta.tool_calls is None
assert count == 77
assert count == 78
assert (
content_generated
== "There was a wise old octopus named Oracle. He lived in a cozy little cave beneath the waves with his best friend, a curious seahorse named Finley. One day, Finley met a playful dolphin named Daisy, and the three became inseparable. They spent their days exploring the ocean, playing hide-and-seek, and learning about the wonders of the sea from Oracle"
== "There was a wise old octopus named Oracle. He lived in a cozy little cave beneath the waves with his best friend, a curious seahorse named Finley. One day, Finley met a playful dolphin named Daisy, and the three became inseparable. They spent their days exploring the ocean, playing hide-and-seek, and learning about the wonders of the sea from Oracle."
)
assert last_response == response_snapshot
@ -401,7 +401,6 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_none(
assert response.choices[0].delta.tool_calls is None
assert count == 100
print(content_generated)
assert (
content_generated
== "Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish, lived three dear friends: Luna the sea turtle, Finley the friendly fish, and Crusty the wise crab.\n\nLuna was the oldest of the three. She had traveled the world, exploring hidden caves and shipwrecks, and collecting sparkling shells and shiny pebbles. Her shell was a beautiful mosaic of blues and greens, and her gentle eyes twinkled with the secrets of the deep"

View File

@ -47,7 +47,7 @@ use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use regex::Regex;
use serde_json::Map;
use serde_json::Value;
use std::convert::Infallible;
use std::fs::File;
@ -1114,84 +1114,183 @@ pub(crate) async fn completions(
}
}
enum StreamState {
Buffering,
BufferTrailing,
Content { skip_close_quote: bool },
// balance the started json with closing braces and quotes
fn complete_json(partial: &str) -> (String, bool) {
let mut brace_count = 0;
let mut quote_open = false;
let mut escaped = false;
let mut last_char = '\0';
for c in partial.chars() {
match (escaped, quote_open, c) {
(true, _, _) => escaped = false,
(false, _, '\\') => escaped = true,
(false, _, '"') => quote_open = !quote_open,
(false, false, '{') => brace_count += 1,
(false, false, '}') => brace_count -= 1,
_ => {}
}
if !c.is_whitespace() {
last_char = c;
}
}
let mut completed = partial.to_string();
if last_char == ',' {
if let Some(pos) = completed.rfind(',') {
completed.replace_range(pos..pos + 1, "");
}
}
if quote_open {
completed.push('"');
}
completed.push_str(&"}".repeat(brace_count.max(0)));
(completed, quote_open)
}
/// Convert a StreamResponse into an Event to be sent over SSE
fn create_event_from_stream_token(
stream_token: &StreamResponse,
logprobs: bool,
stream_options: Option<StreamOptions>,
inner_using_tools: bool,
system_fingerprint: String,
model_id: String,
tool_name: Option<String>,
// Generic function that parses any partial structure into a Map
fn parse_generic_structure(partial: &str) -> Result<Map<String, Value>, String> {
let (completed, _) = complete_json(partial);
match serde_json::from_str::<Value>(&completed) {
Ok(Value::Object(obj)) => Ok(obj),
_ => Err("Failed to parse as object".to_string()),
}
}
// Parse partial JSON into a Map with a function object
fn parse_partial_json(partial: &str) -> Result<Map<String, Value>, String> {
let (completed, was_quote_open) = complete_json(partial);
match serde_json::from_str::<Value>(&completed) {
Ok(Value::Object(obj)) => {
if let Some(Value::Object(function)) = obj.get("function") {
let name_is_only_key = function.len() == 1;
if was_quote_open && name_is_only_key {
let mut function = function.clone();
if let Some(Value::String(ref mut name)) = function.get_mut("_name") {
name.clear();
}
return Err("Missing *name in function".to_string());
}
Ok(function.clone())
} else {
Err("Missing function object".to_string())
}
}
_ => Err("Failed to parse as object".to_string()),
}
}
/// Creates an event based on the token text and event type parameters.
/// `token_text` - The text to include (extract from StreamResponse.token.text or str)
/// `model_id` - Model identifier string
/// `system_fingerprint` - System fingerprint string
/// `tool_name` - If provided, creates a tool call name event
/// `is_tool_arg` - If true, creates a tool call argument event
fn create_event(
token_text: &str,
model_id: &str,
system_fingerprint: &str,
tool_name: Option<&str>,
is_tool_arg: bool,
finish_reason: Option<String>,
) -> Event {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.unwrap_or_default()
.as_secs();
let logprobs = logprobs.then(|| {
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone()))
});
let chat_complete = if let Some(tool_name) = tool_name {
// Tool call name event
let tool_delta = ChatCompletionDelta::Tool(ToolCallDelta {
role: "assistant".to_string(),
tool_calls: vec![DeltaToolCall {
index: 0,
id: String::new(),
r#type: "function".to_string(),
function: Function {
name: Some(tool_name.to_string()),
arguments: "".to_string(),
},
}],
});
// replace the content with the tool calls if grammar is present
let (content, tool_calls) = if inner_using_tools {
(None, Some(vec![stream_token.token.text.clone()]))
CompletionType::ChatCompletionChunk(ChatCompletionChunk {
id: String::new(),
created: current_time,
model: model_id.to_string(),
system_fingerprint: system_fingerprint.to_string(),
choices: vec![ChatCompletionChoice {
index: 0,
delta: tool_delta,
logprobs: None,
finish_reason: None,
}],
usage: None,
})
} else if is_tool_arg {
// Tool call argument event
let tool_delta = ChatCompletionDelta::Tool(ToolCallDelta {
role: "assistant".to_string(),
tool_calls: vec![DeltaToolCall {
index: 0,
id: String::new(),
r#type: "function".to_string(),
function: Function {
name: None,
arguments: token_text.to_string(),
},
}],
});
CompletionType::ChatCompletionChunk(ChatCompletionChunk {
id: String::new(),
created: current_time,
model: model_id.to_string(),
system_fingerprint: system_fingerprint.to_string(),
choices: vec![ChatCompletionChoice {
index: 0,
delta: tool_delta,
logprobs: None,
finish_reason: None,
}],
usage: None,
})
} else {
let content = if !stream_token.token.special {
Some(stream_token.token.text.clone())
// usage, finish_reason
if finish_reason.is_some() {
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.to_string(),
system_fingerprint.to_string(),
Some(token_text.to_string()),
None,
current_time,
None,
finish_reason,
None,
None,
))
} else {
None
};
(content, None)
};
let (usage, finish_reason) = match &stream_token.details {
Some(details) => {
let usage = if stream_options
.as_ref()
.map(|s| s.include_usage)
.unwrap_or(false)
{
let completion_tokens = details.generated_tokens;
let prompt_tokens = details.input_length;
let total_tokens = prompt_tokens + completion_tokens;
Some(Usage {
completion_tokens,
prompt_tokens,
total_tokens,
})
} else {
None
};
(usage, Some(details.finish_reason.format(true)))
// Chat completion event
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.to_string(),
system_fingerprint.to_string(),
Some(token_text.to_string()),
None,
current_time,
None,
None,
None,
None,
))
}
None => (None, None),
};
let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
content,
tool_calls,
current_time,
logprobs,
finish_reason,
usage,
tool_name,
));
event.json_data(chat_complete).unwrap_or_else(|e| {
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
Event::default()
})
Event::default()
.json_data(chat_complete)
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
}
/// Generate tokens
@ -1239,13 +1338,12 @@ pub(crate) async fn chat_completions(
let ChatRequest {
model,
stream,
stream_options,
logprobs,
// TODO: add back and maybe consolidate the other PR
// stream_options,
..
} = chat.clone();
let (generate_request, using_tools): (GenerateRequest, bool) =
chat.try_into_generate(&infer)?;
let (generate_request, using_tools) = chat.try_into_generate(&infer)?;
let logprobs = logprobs.unwrap_or_default();
// extract model id from request if specified
@ -1254,210 +1352,112 @@ pub(crate) async fn chat_completions(
Some(m_id) => m_id.to_string(),
};
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
// switch on stream
if stream {
let (headers, response_stream) =
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
// regex to match any function name
let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
Ok(regex) => regex,
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: format!("Failed to compile regex: {}", e),
error_type: "regex".to_string(),
}),
))
}
};
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
let mut buffer = Vec::new();
let mut json_buffer = String::new();
let mut state = if using_tools {
StreamState::Buffering
} else {
StreamState::Content {
skip_close_quote: false,
}
};
let mut response_as_tool = using_tools;
let mut name_found = !using_tools;
let mut no_tool_chosen = false;
let mut first_quote_removed = false;
while let Some(result) = response_stream.next().await {
match result{
Ok(stream_token) => {
let token_text = &stream_token.token.text.clone();
match state {
StreamState::Buffering => {
json_buffer.push_str(&token_text.replace(" ", ""));
buffer.push(stream_token);
if let Some(captures) = function_regex.captures(&json_buffer) {
let function_name = captures[1].to_string();
if function_name == "no_tool" {
state = StreamState::BufferTrailing;
response_as_tool = false;
buffer.clear();
match result {
Ok(stream_token) => {
let token_text = stream_token.token.text.clone();
json_buffer.push_str(&token_text);
if !name_found {
// since we know tools is attempting to follow a grammar we can attempt to
// partially parse the json_buffer to see if we can extract the function name
if let Ok(function) = parse_partial_json(&json_buffer) {
let name = function.get("_name").and_then(|n| n.as_str()).unwrap_or("no_tool");
name_found = true;
if name == "no_tool" {
no_tool_chosen = true;
json_buffer.clear();
json_buffer.push('{');
} else {
state = StreamState::Content {
skip_close_quote: false,
};
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let tool_delta_start = ChatCompletionDelta::Tool(ToolCallDelta {
role: "assistant".to_string(),
tool_calls: vec![DeltaToolCall {
index: 0,
id: String::new(),
r#type: "function".to_string(),
function: Function {
name: Some(function_name.clone()),
arguments: "".to_string(),
},
}],
});
let chat_complete =
CompletionType::ChatCompletionChunk(ChatCompletionChunk{
id: String::new(),
created: current_time,
model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(),
choices: vec![ChatCompletionChoice {
index: 0,
delta: tool_delta_start,
logprobs: None,
finish_reason: None,
}],
usage: None,
});
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
}));
buffer.drain(1..); // only keep the first token (opening '{')
buffer[0].token.text = buffer[0].token.text.chars().take(1).collect();
let tool_name_event = create_event(&token_text, &model_id, &system_fingerprint, Some(name), false, None);
yield Ok::<Event, Infallible>(tool_name_event);
let tool_open_arguments_event = create_event("{", &model_id, &system_fingerprint, None, true, None);
yield Ok::<Event, Infallible>(tool_open_arguments_event);
// clear the buffer as we know that the buffer is only the function
// ie: ` {"function": {"_name": "get_current_weather",` -> `{"`
// we need to keep the `{` to open the arguments and allow the parser to continue
json_buffer.clear();
json_buffer.push('{');
}
}
}
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
StreamState::BufferTrailing => {
let infix_text = "\"content\":\"";
json_buffer.push_str(&token_text.replace(" ", ""));
// keep capturing until we find the infix text
match json_buffer.find(infix_text) {
Some(content_key_index) => {
json_buffer =
json_buffer[content_key_index + infix_text.len()..].to_string();
}
None => {
} else {
// Process JSON buffer and handle token text
let last_is_brace = json_buffer.ends_with('}');
let edited_buffer = if last_is_brace { &json_buffer[..json_buffer.len() - 1] } else { &json_buffer };
let mut token_text = stream_token.token.text.clone();
let is_json_complete = serde_json::from_str::<Value>(edited_buffer).is_ok();
// Handle tool usage cases
if using_tools {
if no_tool_chosen {
// Tool without selection ("content" flow)
if let Ok(function) = parse_generic_structure(edited_buffer) {
if function.get("content").and_then(|c| c.as_str()).is_some() {
// Handle quotation marks
if !first_quote_removed {
first_quote_removed = true;
if token_text == "\"" || token_text == " \"" { continue; }
token_text = token_text.replace("\"", "");
} else if token_text.ends_with('"') {
token_text = token_text[..token_text.len() - 1].to_string();
}
if is_json_complete { break; }
yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, false, None));
continue;
}
}
continue;
} else {
// Tool with selection
if is_json_complete {
// Final token with possible brace removal
if last_is_brace { token_text = token_text[..token_text.len() - 1].to_string(); }
yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, true, None));
break;
}
yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, true, None));
continue;
}
}
// if there is leftover text after removing the infix text, we need to send it
if !json_buffer.is_empty() {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let chat_complete =
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
Some(json_buffer.clone()),
None,
current_time,
None,
None,
None,
None,
));
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
}));
}
// cleanup the buffers
buffer.clear();
json_buffer.clear();
state = StreamState::Content {
skip_close_quote: true,
};
}
StreamState::Content { skip_close_quote } => {
if skip_close_quote && token_text.contains('"') {
break;
}
buffer.push(stream_token);
if buffer.len() > 1 {
// FIFO send the buffer but left the last two elements (closing '}' and EOS token)
for stream_token in &buffer[..buffer.len() - 2] {
let event = create_event_from_stream_token(
stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
None,
);
yield Ok::<Event, Infallible>(event);
} else {
// Default case: standard chat completion
if let Some(details) = stream_token.details.as_ref() {
// Handle final token and only send text if ended because of length
let text = if details.finish_reason == FinishReason::Length { &token_text } else { "" };
yield Ok::<Event, Infallible>(create_event(text, &model_id, &system_fingerprint, None, false, Some(details.finish_reason.format(true))));
break;
}
buffer = buffer.drain(buffer.len() - 2..).collect();
yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, false, None));
}
}
}
}
Err(err) => yield Ok(err.into_openai_event())
Err(err) => yield Ok(err.into_openai_event()),
}
}
if response_as_tool {
// send the second to last stream token but remove the trailing '}' if it exists
let mut closing_stream_token = buffer.remove(0);
closing_stream_token.token.text = closing_stream_token.token.text.strip_suffix("}").unwrap_or(&closing_stream_token.token.text).to_string();
let event = create_event_from_stream_token(
&closing_stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
None,
);
yield Ok::<Event, Infallible>(event);
} else {
// send each buffer element
for stream_token in buffer {
let event = create_event_from_stream_token(
&stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
None,
);
yield Ok::<Event, Infallible>(event);
}
}
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
};
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response())
} else {
// Non-streaming case
let (headers, input_length, Json(generation)) =
generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.unwrap_or_default()
.as_secs();
let (tool_calls, output) = if using_tools {
@ -1490,11 +1490,9 @@ pub(crate) async fn chat_completions(
let content_message = arguments
.get("content")
.and_then(Value::as_str)
.ok_or_else(|| {
InferError::ToolError(
"No `content` found in generated text".to_string(),
)
})?
.ok_or(InferError::ToolError(
"No `content` found in generated text".to_string(),
))?
.to_string();
(None, Some(content_message))
}