mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-09 19:02:09 +00:00
feat: refactor chat stream to remove state machine and simplfy logic
This commit is contained in:
parent
a416ddbdd9
commit
31a536d796
@ -2,7 +2,7 @@
|
|||||||
"choices": [
|
"choices": [
|
||||||
{
|
{
|
||||||
"delta": {
|
"delta": {
|
||||||
"content": " assistant",
|
"content": "!",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"tool_calls": null
|
"tool_calls": null
|
||||||
},
|
},
|
||||||
@ -11,7 +11,7 @@
|
|||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1739441937,
|
"created": 1740432006,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
"choices": [
|
"choices": [
|
||||||
{
|
{
|
||||||
"delta": {
|
"delta": {
|
||||||
"content": " Oracle",
|
"content": ".",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"tool_calls": null
|
"tool_calls": null
|
||||||
},
|
},
|
||||||
@ -11,7 +11,7 @@
|
|||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1739444803,
|
"created": 1740432012,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
|
@ -11,7 +11,7 @@
|
|||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1739454835,
|
"created": 1740433572,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
|
@ -281,8 +281,8 @@ async def test_flash_llama_grammar_tools_insufficient_information_stream(
|
|||||||
last_response = response
|
last_response = response
|
||||||
assert response.choices[0].delta.tool_calls is None
|
assert response.choices[0].delta.tool_calls is None
|
||||||
|
|
||||||
assert count == 5
|
assert count == 6
|
||||||
assert content_generated == "I am a helpful assistant"
|
assert content_generated == "I am a helpful assistant!"
|
||||||
assert last_response == response_snapshot
|
assert last_response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@ -318,10 +318,10 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream(
|
|||||||
last_response = response
|
last_response = response
|
||||||
assert response.choices[0].delta.tool_calls is None
|
assert response.choices[0].delta.tool_calls is None
|
||||||
|
|
||||||
assert count == 77
|
assert count == 78
|
||||||
assert (
|
assert (
|
||||||
content_generated
|
content_generated
|
||||||
== "There was a wise old octopus named Oracle. He lived in a cozy little cave beneath the waves with his best friend, a curious seahorse named Finley. One day, Finley met a playful dolphin named Daisy, and the three became inseparable. They spent their days exploring the ocean, playing hide-and-seek, and learning about the wonders of the sea from Oracle"
|
== "There was a wise old octopus named Oracle. He lived in a cozy little cave beneath the waves with his best friend, a curious seahorse named Finley. One day, Finley met a playful dolphin named Daisy, and the three became inseparable. They spent their days exploring the ocean, playing hide-and-seek, and learning about the wonders of the sea from Oracle."
|
||||||
)
|
)
|
||||||
assert last_response == response_snapshot
|
assert last_response == response_snapshot
|
||||||
|
|
||||||
@ -401,7 +401,6 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_none(
|
|||||||
assert response.choices[0].delta.tool_calls is None
|
assert response.choices[0].delta.tool_calls is None
|
||||||
|
|
||||||
assert count == 100
|
assert count == 100
|
||||||
print(content_generated)
|
|
||||||
assert (
|
assert (
|
||||||
content_generated
|
content_generated
|
||||||
== "Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish, lived three dear friends: Luna the sea turtle, Finley the friendly fish, and Crusty the wise crab.\n\nLuna was the oldest of the three. She had traveled the world, exploring hidden caves and shipwrecks, and collecting sparkling shells and shiny pebbles. Her shell was a beautiful mosaic of blues and greens, and her gentle eyes twinkled with the secrets of the deep"
|
== "Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish, lived three dear friends: Luna the sea turtle, Finley the friendly fish, and Crusty the wise crab.\n\nLuna was the oldest of the three. She had traveled the world, exploring hidden caves and shipwrecks, and collecting sparkling shells and shiny pebbles. Her shell was a beautiful mosaic of blues and greens, and her gentle eyes twinkled with the secrets of the deep"
|
||||||
|
@ -47,7 +47,7 @@ use http::header::AUTHORIZATION;
|
|||||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
use pyo3::types::IntoPyDict;
|
use pyo3::types::IntoPyDict;
|
||||||
use regex::Regex;
|
use serde_json::Map;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
@ -1114,84 +1114,183 @@ pub(crate) async fn completions(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
enum StreamState {
|
// balance the started json with closing braces and quotes
|
||||||
Buffering,
|
fn complete_json(partial: &str) -> (String, bool) {
|
||||||
BufferTrailing,
|
let mut brace_count = 0;
|
||||||
Content { skip_close_quote: bool },
|
let mut quote_open = false;
|
||||||
|
let mut escaped = false;
|
||||||
|
let mut last_char = '\0';
|
||||||
|
|
||||||
|
for c in partial.chars() {
|
||||||
|
match (escaped, quote_open, c) {
|
||||||
|
(true, _, _) => escaped = false,
|
||||||
|
(false, _, '\\') => escaped = true,
|
||||||
|
(false, _, '"') => quote_open = !quote_open,
|
||||||
|
(false, false, '{') => brace_count += 1,
|
||||||
|
(false, false, '}') => brace_count -= 1,
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
if !c.is_whitespace() {
|
||||||
|
last_char = c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut completed = partial.to_string();
|
||||||
|
|
||||||
|
if last_char == ',' {
|
||||||
|
if let Some(pos) = completed.rfind(',') {
|
||||||
|
completed.replace_range(pos..pos + 1, "");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if quote_open {
|
||||||
|
completed.push('"');
|
||||||
|
}
|
||||||
|
completed.push_str(&"}".repeat(brace_count.max(0)));
|
||||||
|
|
||||||
|
(completed, quote_open)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert a StreamResponse into an Event to be sent over SSE
|
// Generic function that parses any partial structure into a Map
|
||||||
fn create_event_from_stream_token(
|
fn parse_generic_structure(partial: &str) -> Result<Map<String, Value>, String> {
|
||||||
stream_token: &StreamResponse,
|
let (completed, _) = complete_json(partial);
|
||||||
logprobs: bool,
|
match serde_json::from_str::<Value>(&completed) {
|
||||||
stream_options: Option<StreamOptions>,
|
Ok(Value::Object(obj)) => Ok(obj),
|
||||||
inner_using_tools: bool,
|
_ => Err("Failed to parse as object".to_string()),
|
||||||
system_fingerprint: String,
|
}
|
||||||
model_id: String,
|
}
|
||||||
tool_name: Option<String>,
|
|
||||||
|
// Parse partial JSON into a Map with a function object
|
||||||
|
fn parse_partial_json(partial: &str) -> Result<Map<String, Value>, String> {
|
||||||
|
let (completed, was_quote_open) = complete_json(partial);
|
||||||
|
match serde_json::from_str::<Value>(&completed) {
|
||||||
|
Ok(Value::Object(obj)) => {
|
||||||
|
if let Some(Value::Object(function)) = obj.get("function") {
|
||||||
|
let name_is_only_key = function.len() == 1;
|
||||||
|
if was_quote_open && name_is_only_key {
|
||||||
|
let mut function = function.clone();
|
||||||
|
if let Some(Value::String(ref mut name)) = function.get_mut("_name") {
|
||||||
|
name.clear();
|
||||||
|
}
|
||||||
|
return Err("Missing *name in function".to_string());
|
||||||
|
}
|
||||||
|
Ok(function.clone())
|
||||||
|
} else {
|
||||||
|
Err("Missing function object".to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err("Failed to parse as object".to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an event based on the token text and event type parameters.
|
||||||
|
/// `token_text` - The text to include (extract from StreamResponse.token.text or str)
|
||||||
|
/// `model_id` - Model identifier string
|
||||||
|
/// `system_fingerprint` - System fingerprint string
|
||||||
|
/// `tool_name` - If provided, creates a tool call name event
|
||||||
|
/// `is_tool_arg` - If true, creates a tool call argument event
|
||||||
|
fn create_event(
|
||||||
|
token_text: &str,
|
||||||
|
model_id: &str,
|
||||||
|
system_fingerprint: &str,
|
||||||
|
tool_name: Option<&str>,
|
||||||
|
is_tool_arg: bool,
|
||||||
|
finish_reason: Option<String>,
|
||||||
) -> Event {
|
) -> Event {
|
||||||
let event = Event::default();
|
|
||||||
let current_time = std::time::SystemTime::now()
|
let current_time = std::time::SystemTime::now()
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
.unwrap_or_default()
|
||||||
.as_secs();
|
.as_secs();
|
||||||
|
|
||||||
let logprobs = logprobs.then(|| {
|
let chat_complete = if let Some(tool_name) = tool_name {
|
||||||
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone()))
|
// Tool call name event
|
||||||
});
|
let tool_delta = ChatCompletionDelta::Tool(ToolCallDelta {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
tool_calls: vec![DeltaToolCall {
|
||||||
|
index: 0,
|
||||||
|
id: String::new(),
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: Function {
|
||||||
|
name: Some(tool_name.to_string()),
|
||||||
|
arguments: "".to_string(),
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
|
||||||
// replace the content with the tool calls if grammar is present
|
CompletionType::ChatCompletionChunk(ChatCompletionChunk {
|
||||||
let (content, tool_calls) = if inner_using_tools {
|
id: String::new(),
|
||||||
(None, Some(vec![stream_token.token.text.clone()]))
|
created: current_time,
|
||||||
|
model: model_id.to_string(),
|
||||||
|
system_fingerprint: system_fingerprint.to_string(),
|
||||||
|
choices: vec![ChatCompletionChoice {
|
||||||
|
index: 0,
|
||||||
|
delta: tool_delta,
|
||||||
|
logprobs: None,
|
||||||
|
finish_reason: None,
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
})
|
||||||
|
} else if is_tool_arg {
|
||||||
|
// Tool call argument event
|
||||||
|
let tool_delta = ChatCompletionDelta::Tool(ToolCallDelta {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
tool_calls: vec![DeltaToolCall {
|
||||||
|
index: 0,
|
||||||
|
id: String::new(),
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: Function {
|
||||||
|
name: None,
|
||||||
|
arguments: token_text.to_string(),
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
|
||||||
|
CompletionType::ChatCompletionChunk(ChatCompletionChunk {
|
||||||
|
id: String::new(),
|
||||||
|
created: current_time,
|
||||||
|
model: model_id.to_string(),
|
||||||
|
system_fingerprint: system_fingerprint.to_string(),
|
||||||
|
choices: vec![ChatCompletionChoice {
|
||||||
|
index: 0,
|
||||||
|
delta: tool_delta,
|
||||||
|
logprobs: None,
|
||||||
|
finish_reason: None,
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
})
|
||||||
} else {
|
} else {
|
||||||
let content = if !stream_token.token.special {
|
// usage, finish_reason
|
||||||
Some(stream_token.token.text.clone())
|
if finish_reason.is_some() {
|
||||||
|
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
||||||
|
model_id.to_string(),
|
||||||
|
system_fingerprint.to_string(),
|
||||||
|
Some(token_text.to_string()),
|
||||||
|
None,
|
||||||
|
current_time,
|
||||||
|
None,
|
||||||
|
finish_reason,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
None
|
// Chat completion event
|
||||||
};
|
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
||||||
|
model_id.to_string(),
|
||||||
(content, None)
|
system_fingerprint.to_string(),
|
||||||
};
|
Some(token_text.to_string()),
|
||||||
|
None,
|
||||||
let (usage, finish_reason) = match &stream_token.details {
|
current_time,
|
||||||
Some(details) => {
|
None,
|
||||||
let usage = if stream_options
|
None,
|
||||||
.as_ref()
|
None,
|
||||||
.map(|s| s.include_usage)
|
None,
|
||||||
.unwrap_or(false)
|
))
|
||||||
{
|
|
||||||
let completion_tokens = details.generated_tokens;
|
|
||||||
let prompt_tokens = details.input_length;
|
|
||||||
let total_tokens = prompt_tokens + completion_tokens;
|
|
||||||
Some(Usage {
|
|
||||||
completion_tokens,
|
|
||||||
prompt_tokens,
|
|
||||||
total_tokens,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
(usage, Some(details.finish_reason.format(true)))
|
|
||||||
}
|
}
|
||||||
None => (None, None),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
Event::default()
|
||||||
model_id.clone(),
|
.json_data(chat_complete)
|
||||||
system_fingerprint.clone(),
|
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
|
||||||
content,
|
|
||||||
tool_calls,
|
|
||||||
current_time,
|
|
||||||
logprobs,
|
|
||||||
finish_reason,
|
|
||||||
usage,
|
|
||||||
tool_name,
|
|
||||||
));
|
|
||||||
|
|
||||||
event.json_data(chat_complete).unwrap_or_else(|e| {
|
|
||||||
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
|
|
||||||
Event::default()
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate tokens
|
/// Generate tokens
|
||||||
@ -1239,13 +1338,12 @@ pub(crate) async fn chat_completions(
|
|||||||
let ChatRequest {
|
let ChatRequest {
|
||||||
model,
|
model,
|
||||||
stream,
|
stream,
|
||||||
stream_options,
|
|
||||||
logprobs,
|
logprobs,
|
||||||
|
// TODO: add back and maybe consolidate the other PR
|
||||||
|
// stream_options,
|
||||||
..
|
..
|
||||||
} = chat.clone();
|
} = chat.clone();
|
||||||
let (generate_request, using_tools): (GenerateRequest, bool) =
|
let (generate_request, using_tools) = chat.try_into_generate(&infer)?;
|
||||||
chat.try_into_generate(&infer)?;
|
|
||||||
|
|
||||||
let logprobs = logprobs.unwrap_or_default();
|
let logprobs = logprobs.unwrap_or_default();
|
||||||
|
|
||||||
// extract model id from request if specified
|
// extract model id from request if specified
|
||||||
@ -1254,210 +1352,112 @@ pub(crate) async fn chat_completions(
|
|||||||
Some(m_id) => m_id.to_string(),
|
Some(m_id) => m_id.to_string(),
|
||||||
};
|
};
|
||||||
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
||||||
// switch on stream
|
|
||||||
if stream {
|
if stream {
|
||||||
let (headers, response_stream) =
|
let (headers, response_stream) =
|
||||||
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
|
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
|
||||||
|
|
||||||
// regex to match any function name
|
|
||||||
let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
|
|
||||||
Ok(regex) => regex,
|
|
||||||
Err(e) => {
|
|
||||||
return Err((
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: format!("Failed to compile regex: {}", e),
|
|
||||||
error_type: "regex".to_string(),
|
|
||||||
}),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let response_stream = async_stream::stream! {
|
let response_stream = async_stream::stream! {
|
||||||
let mut response_stream = Box::pin(response_stream);
|
let mut response_stream = Box::pin(response_stream);
|
||||||
let mut buffer = Vec::new();
|
|
||||||
let mut json_buffer = String::new();
|
let mut json_buffer = String::new();
|
||||||
let mut state = if using_tools {
|
let mut name_found = !using_tools;
|
||||||
StreamState::Buffering
|
let mut no_tool_chosen = false;
|
||||||
} else {
|
let mut first_quote_removed = false;
|
||||||
StreamState::Content {
|
|
||||||
skip_close_quote: false,
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let mut response_as_tool = using_tools;
|
|
||||||
while let Some(result) = response_stream.next().await {
|
while let Some(result) = response_stream.next().await {
|
||||||
match result{
|
match result {
|
||||||
Ok(stream_token) => {
|
Ok(stream_token) => {
|
||||||
let token_text = &stream_token.token.text.clone();
|
let token_text = stream_token.token.text.clone();
|
||||||
match state {
|
json_buffer.push_str(&token_text);
|
||||||
StreamState::Buffering => {
|
if !name_found {
|
||||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
// since we know tools is attempting to follow a grammar we can attempt to
|
||||||
buffer.push(stream_token);
|
// partially parse the json_buffer to see if we can extract the function name
|
||||||
if let Some(captures) = function_regex.captures(&json_buffer) {
|
if let Ok(function) = parse_partial_json(&json_buffer) {
|
||||||
let function_name = captures[1].to_string();
|
let name = function.get("_name").and_then(|n| n.as_str()).unwrap_or("no_tool");
|
||||||
if function_name == "no_tool" {
|
name_found = true;
|
||||||
state = StreamState::BufferTrailing;
|
if name == "no_tool" {
|
||||||
response_as_tool = false;
|
no_tool_chosen = true;
|
||||||
buffer.clear();
|
|
||||||
json_buffer.clear();
|
json_buffer.clear();
|
||||||
|
json_buffer.push('{');
|
||||||
} else {
|
} else {
|
||||||
state = StreamState::Content {
|
let tool_name_event = create_event(&token_text, &model_id, &system_fingerprint, Some(name), false, None);
|
||||||
skip_close_quote: false,
|
yield Ok::<Event, Infallible>(tool_name_event);
|
||||||
};
|
let tool_open_arguments_event = create_event("{", &model_id, &system_fingerprint, None, true, None);
|
||||||
let event = Event::default();
|
yield Ok::<Event, Infallible>(tool_open_arguments_event);
|
||||||
let current_time = std::time::SystemTime::now()
|
// clear the buffer as we know that the buffer is only the function
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
// ie: ` {"function": {"_name": "get_current_weather",` -> `{"`
|
||||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
// we need to keep the `{` to open the arguments and allow the parser to continue
|
||||||
.as_secs();
|
json_buffer.clear();
|
||||||
let tool_delta_start = ChatCompletionDelta::Tool(ToolCallDelta {
|
json_buffer.push('{');
|
||||||
role: "assistant".to_string(),
|
|
||||||
tool_calls: vec![DeltaToolCall {
|
|
||||||
index: 0,
|
|
||||||
id: String::new(),
|
|
||||||
r#type: "function".to_string(),
|
|
||||||
function: Function {
|
|
||||||
name: Some(function_name.clone()),
|
|
||||||
arguments: "".to_string(),
|
|
||||||
},
|
|
||||||
}],
|
|
||||||
});
|
|
||||||
let chat_complete =
|
|
||||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk{
|
|
||||||
id: String::new(),
|
|
||||||
created: current_time,
|
|
||||||
model: model_id.clone(),
|
|
||||||
system_fingerprint: system_fingerprint.clone(),
|
|
||||||
choices: vec![ChatCompletionChoice {
|
|
||||||
index: 0,
|
|
||||||
delta: tool_delta_start,
|
|
||||||
logprobs: None,
|
|
||||||
finish_reason: None,
|
|
||||||
}],
|
|
||||||
usage: None,
|
|
||||||
});
|
|
||||||
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
|
|
||||||
InferError::StreamSerializationError(e.to_string()).into()
|
|
||||||
}));
|
|
||||||
buffer.drain(1..); // only keep the first token (opening '{')
|
|
||||||
buffer[0].token.text = buffer[0].token.text.chars().take(1).collect();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
|
// Process JSON buffer and handle token text
|
||||||
StreamState::BufferTrailing => {
|
let last_is_brace = json_buffer.ends_with('}');
|
||||||
let infix_text = "\"content\":\"";
|
let edited_buffer = if last_is_brace { &json_buffer[..json_buffer.len() - 1] } else { &json_buffer };
|
||||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
let mut token_text = stream_token.token.text.clone();
|
||||||
// keep capturing until we find the infix text
|
let is_json_complete = serde_json::from_str::<Value>(edited_buffer).is_ok();
|
||||||
match json_buffer.find(infix_text) {
|
|
||||||
Some(content_key_index) => {
|
// Handle tool usage cases
|
||||||
json_buffer =
|
if using_tools {
|
||||||
json_buffer[content_key_index + infix_text.len()..].to_string();
|
if no_tool_chosen {
|
||||||
}
|
// Tool without selection ("content" flow)
|
||||||
None => {
|
if let Ok(function) = parse_generic_structure(edited_buffer) {
|
||||||
|
if function.get("content").and_then(|c| c.as_str()).is_some() {
|
||||||
|
// Handle quotation marks
|
||||||
|
if !first_quote_removed {
|
||||||
|
first_quote_removed = true;
|
||||||
|
if token_text == "\"" || token_text == " \"" { continue; }
|
||||||
|
token_text = token_text.replace("\"", "");
|
||||||
|
} else if token_text.ends_with('"') {
|
||||||
|
token_text = token_text[..token_text.len() - 1].to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
if is_json_complete { break; }
|
||||||
|
yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, false, None));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
// Tool with selection
|
||||||
|
if is_json_complete {
|
||||||
|
// Final token with possible brace removal
|
||||||
|
if last_is_brace { token_text = token_text[..token_text.len() - 1].to_string(); }
|
||||||
|
yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, true, None));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, true, None));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
// if there is leftover text after removing the infix text, we need to send it
|
// Default case: standard chat completion
|
||||||
if !json_buffer.is_empty() {
|
if let Some(details) = stream_token.details.as_ref() {
|
||||||
let event = Event::default();
|
// Handle final token and only send text if ended because of length
|
||||||
let current_time = std::time::SystemTime::now()
|
let text = if details.finish_reason == FinishReason::Length { &token_text } else { "" };
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
yield Ok::<Event, Infallible>(create_event(text, &model_id, &system_fingerprint, None, false, Some(details.finish_reason.format(true))));
|
||||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
break;
|
||||||
.as_secs();
|
|
||||||
let chat_complete =
|
|
||||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
|
||||||
model_id.clone(),
|
|
||||||
system_fingerprint.clone(),
|
|
||||||
Some(json_buffer.clone()),
|
|
||||||
None,
|
|
||||||
current_time,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
));
|
|
||||||
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
|
|
||||||
InferError::StreamSerializationError(e.to_string()).into()
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
// cleanup the buffers
|
|
||||||
buffer.clear();
|
|
||||||
json_buffer.clear();
|
|
||||||
state = StreamState::Content {
|
|
||||||
skip_close_quote: true,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
StreamState::Content { skip_close_quote } => {
|
|
||||||
if skip_close_quote && token_text.contains('"') {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
buffer.push(stream_token);
|
|
||||||
if buffer.len() > 1 {
|
|
||||||
// FIFO send the buffer but left the last two elements (closing '}' and EOS token)
|
|
||||||
for stream_token in &buffer[..buffer.len() - 2] {
|
|
||||||
let event = create_event_from_stream_token(
|
|
||||||
stream_token,
|
|
||||||
logprobs,
|
|
||||||
stream_options.clone(),
|
|
||||||
response_as_tool,
|
|
||||||
system_fingerprint.clone(),
|
|
||||||
model_id.clone(),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
yield Ok::<Event, Infallible>(event);
|
|
||||||
}
|
}
|
||||||
buffer = buffer.drain(buffer.len() - 2..).collect();
|
yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, false, None));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
Err(err) => yield Ok(err.into_openai_event()),
|
||||||
Err(err) => yield Ok(err.into_openai_event())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if response_as_tool {
|
|
||||||
// send the second to last stream token but remove the trailing '}' if it exists
|
|
||||||
let mut closing_stream_token = buffer.remove(0);
|
|
||||||
closing_stream_token.token.text = closing_stream_token.token.text.strip_suffix("}").unwrap_or(&closing_stream_token.token.text).to_string();
|
|
||||||
let event = create_event_from_stream_token(
|
|
||||||
&closing_stream_token,
|
|
||||||
logprobs,
|
|
||||||
stream_options.clone(),
|
|
||||||
response_as_tool,
|
|
||||||
system_fingerprint.clone(),
|
|
||||||
model_id.clone(),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
yield Ok::<Event, Infallible>(event);
|
|
||||||
} else {
|
|
||||||
// send each buffer element
|
|
||||||
for stream_token in buffer {
|
|
||||||
let event = create_event_from_stream_token(
|
|
||||||
&stream_token,
|
|
||||||
logprobs,
|
|
||||||
stream_options.clone(),
|
|
||||||
response_as_tool,
|
|
||||||
system_fingerprint.clone(),
|
|
||||||
model_id.clone(),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
yield Ok::<Event, Infallible>(event);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
|
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
|
||||||
};
|
};
|
||||||
|
|
||||||
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||||
Ok((headers, sse).into_response())
|
Ok((headers, sse).into_response())
|
||||||
} else {
|
} else {
|
||||||
|
// Non-streaming case
|
||||||
let (headers, input_length, Json(generation)) =
|
let (headers, input_length, Json(generation)) =
|
||||||
generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
|
generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
|
||||||
|
|
||||||
let current_time = std::time::SystemTime::now()
|
let current_time = std::time::SystemTime::now()
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
.unwrap_or_default()
|
||||||
.as_secs();
|
.as_secs();
|
||||||
|
|
||||||
let (tool_calls, output) = if using_tools {
|
let (tool_calls, output) = if using_tools {
|
||||||
@ -1490,11 +1490,9 @@ pub(crate) async fn chat_completions(
|
|||||||
let content_message = arguments
|
let content_message = arguments
|
||||||
.get("content")
|
.get("content")
|
||||||
.and_then(Value::as_str)
|
.and_then(Value::as_str)
|
||||||
.ok_or_else(|| {
|
.ok_or(InferError::ToolError(
|
||||||
InferError::ToolError(
|
"No `content` found in generated text".to_string(),
|
||||||
"No `content` found in generated text".to_string(),
|
))?
|
||||||
)
|
|
||||||
})?
|
|
||||||
.to_string();
|
.to_string();
|
||||||
(None, Some(content_message))
|
(None, Some(content_message))
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user