mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-09 10:52:07 +00:00
feat: refactor chat stream to remove state machine and simplfy logic
This commit is contained in:
parent
a416ddbdd9
commit
31a536d796
@ -2,7 +2,7 @@
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": " assistant",
|
||||
"content": "!",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
@ -11,7 +11,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1739441937,
|
||||
"created": 1740432006,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
|
@ -2,7 +2,7 @@
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": " Oracle",
|
||||
"content": ".",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
@ -11,7 +11,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1739444803,
|
||||
"created": 1740432012,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
|
@ -11,7 +11,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1739454835,
|
||||
"created": 1740433572,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
|
@ -281,8 +281,8 @@ async def test_flash_llama_grammar_tools_insufficient_information_stream(
|
||||
last_response = response
|
||||
assert response.choices[0].delta.tool_calls is None
|
||||
|
||||
assert count == 5
|
||||
assert content_generated == "I am a helpful assistant"
|
||||
assert count == 6
|
||||
assert content_generated == "I am a helpful assistant!"
|
||||
assert last_response == response_snapshot
|
||||
|
||||
|
||||
@ -318,10 +318,10 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream(
|
||||
last_response = response
|
||||
assert response.choices[0].delta.tool_calls is None
|
||||
|
||||
assert count == 77
|
||||
assert count == 78
|
||||
assert (
|
||||
content_generated
|
||||
== "There was a wise old octopus named Oracle. He lived in a cozy little cave beneath the waves with his best friend, a curious seahorse named Finley. One day, Finley met a playful dolphin named Daisy, and the three became inseparable. They spent their days exploring the ocean, playing hide-and-seek, and learning about the wonders of the sea from Oracle"
|
||||
== "There was a wise old octopus named Oracle. He lived in a cozy little cave beneath the waves with his best friend, a curious seahorse named Finley. One day, Finley met a playful dolphin named Daisy, and the three became inseparable. They spent their days exploring the ocean, playing hide-and-seek, and learning about the wonders of the sea from Oracle."
|
||||
)
|
||||
assert last_response == response_snapshot
|
||||
|
||||
@ -401,7 +401,6 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_none(
|
||||
assert response.choices[0].delta.tool_calls is None
|
||||
|
||||
assert count == 100
|
||||
print(content_generated)
|
||||
assert (
|
||||
content_generated
|
||||
== "Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish, lived three dear friends: Luna the sea turtle, Finley the friendly fish, and Crusty the wise crab.\n\nLuna was the oldest of the three. She had traveled the world, exploring hidden caves and shipwrecks, and collecting sparkling shells and shiny pebbles. Her shell was a beautiful mosaic of blues and greens, and her gentle eyes twinkled with the secrets of the deep"
|
||||
|
@ -47,7 +47,7 @@ use http::header::AUTHORIZATION;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::IntoPyDict;
|
||||
use regex::Regex;
|
||||
use serde_json::Map;
|
||||
use serde_json::Value;
|
||||
use std::convert::Infallible;
|
||||
use std::fs::File;
|
||||
@ -1114,84 +1114,183 @@ pub(crate) async fn completions(
|
||||
}
|
||||
}
|
||||
|
||||
enum StreamState {
|
||||
Buffering,
|
||||
BufferTrailing,
|
||||
Content { skip_close_quote: bool },
|
||||
// balance the started json with closing braces and quotes
|
||||
fn complete_json(partial: &str) -> (String, bool) {
|
||||
let mut brace_count = 0;
|
||||
let mut quote_open = false;
|
||||
let mut escaped = false;
|
||||
let mut last_char = '\0';
|
||||
|
||||
for c in partial.chars() {
|
||||
match (escaped, quote_open, c) {
|
||||
(true, _, _) => escaped = false,
|
||||
(false, _, '\\') => escaped = true,
|
||||
(false, _, '"') => quote_open = !quote_open,
|
||||
(false, false, '{') => brace_count += 1,
|
||||
(false, false, '}') => brace_count -= 1,
|
||||
_ => {}
|
||||
}
|
||||
if !c.is_whitespace() {
|
||||
last_char = c;
|
||||
}
|
||||
}
|
||||
|
||||
let mut completed = partial.to_string();
|
||||
|
||||
if last_char == ',' {
|
||||
if let Some(pos) = completed.rfind(',') {
|
||||
completed.replace_range(pos..pos + 1, "");
|
||||
}
|
||||
}
|
||||
|
||||
if quote_open {
|
||||
completed.push('"');
|
||||
}
|
||||
completed.push_str(&"}".repeat(brace_count.max(0)));
|
||||
|
||||
(completed, quote_open)
|
||||
}
|
||||
|
||||
/// Convert a StreamResponse into an Event to be sent over SSE
|
||||
fn create_event_from_stream_token(
|
||||
stream_token: &StreamResponse,
|
||||
logprobs: bool,
|
||||
stream_options: Option<StreamOptions>,
|
||||
inner_using_tools: bool,
|
||||
system_fingerprint: String,
|
||||
model_id: String,
|
||||
tool_name: Option<String>,
|
||||
// Generic function that parses any partial structure into a Map
|
||||
fn parse_generic_structure(partial: &str) -> Result<Map<String, Value>, String> {
|
||||
let (completed, _) = complete_json(partial);
|
||||
match serde_json::from_str::<Value>(&completed) {
|
||||
Ok(Value::Object(obj)) => Ok(obj),
|
||||
_ => Err("Failed to parse as object".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
// Parse partial JSON into a Map with a function object
|
||||
fn parse_partial_json(partial: &str) -> Result<Map<String, Value>, String> {
|
||||
let (completed, was_quote_open) = complete_json(partial);
|
||||
match serde_json::from_str::<Value>(&completed) {
|
||||
Ok(Value::Object(obj)) => {
|
||||
if let Some(Value::Object(function)) = obj.get("function") {
|
||||
let name_is_only_key = function.len() == 1;
|
||||
if was_quote_open && name_is_only_key {
|
||||
let mut function = function.clone();
|
||||
if let Some(Value::String(ref mut name)) = function.get_mut("_name") {
|
||||
name.clear();
|
||||
}
|
||||
return Err("Missing *name in function".to_string());
|
||||
}
|
||||
Ok(function.clone())
|
||||
} else {
|
||||
Err("Missing function object".to_string())
|
||||
}
|
||||
}
|
||||
_ => Err("Failed to parse as object".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates an event based on the token text and event type parameters.
|
||||
/// `token_text` - The text to include (extract from StreamResponse.token.text or str)
|
||||
/// `model_id` - Model identifier string
|
||||
/// `system_fingerprint` - System fingerprint string
|
||||
/// `tool_name` - If provided, creates a tool call name event
|
||||
/// `is_tool_arg` - If true, creates a tool call argument event
|
||||
fn create_event(
|
||||
token_text: &str,
|
||||
model_id: &str,
|
||||
system_fingerprint: &str,
|
||||
tool_name: Option<&str>,
|
||||
is_tool_arg: bool,
|
||||
finish_reason: Option<String>,
|
||||
) -> Event {
|
||||
let event = Event::default();
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let logprobs = logprobs.then(|| {
|
||||
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone()))
|
||||
});
|
||||
let chat_complete = if let Some(tool_name) = tool_name {
|
||||
// Tool call name event
|
||||
let tool_delta = ChatCompletionDelta::Tool(ToolCallDelta {
|
||||
role: "assistant".to_string(),
|
||||
tool_calls: vec![DeltaToolCall {
|
||||
index: 0,
|
||||
id: String::new(),
|
||||
r#type: "function".to_string(),
|
||||
function: Function {
|
||||
name: Some(tool_name.to_string()),
|
||||
arguments: "".to_string(),
|
||||
},
|
||||
}],
|
||||
});
|
||||
|
||||
// replace the content with the tool calls if grammar is present
|
||||
let (content, tool_calls) = if inner_using_tools {
|
||||
(None, Some(vec![stream_token.token.text.clone()]))
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk {
|
||||
id: String::new(),
|
||||
created: current_time,
|
||||
model: model_id.to_string(),
|
||||
system_fingerprint: system_fingerprint.to_string(),
|
||||
choices: vec![ChatCompletionChoice {
|
||||
index: 0,
|
||||
delta: tool_delta,
|
||||
logprobs: None,
|
||||
finish_reason: None,
|
||||
}],
|
||||
usage: None,
|
||||
})
|
||||
} else if is_tool_arg {
|
||||
// Tool call argument event
|
||||
let tool_delta = ChatCompletionDelta::Tool(ToolCallDelta {
|
||||
role: "assistant".to_string(),
|
||||
tool_calls: vec![DeltaToolCall {
|
||||
index: 0,
|
||||
id: String::new(),
|
||||
r#type: "function".to_string(),
|
||||
function: Function {
|
||||
name: None,
|
||||
arguments: token_text.to_string(),
|
||||
},
|
||||
}],
|
||||
});
|
||||
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk {
|
||||
id: String::new(),
|
||||
created: current_time,
|
||||
model: model_id.to_string(),
|
||||
system_fingerprint: system_fingerprint.to_string(),
|
||||
choices: vec![ChatCompletionChoice {
|
||||
index: 0,
|
||||
delta: tool_delta,
|
||||
logprobs: None,
|
||||
finish_reason: None,
|
||||
}],
|
||||
usage: None,
|
||||
})
|
||||
} else {
|
||||
let content = if !stream_token.token.special {
|
||||
Some(stream_token.token.text.clone())
|
||||
// usage, finish_reason
|
||||
if finish_reason.is_some() {
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
||||
model_id.to_string(),
|
||||
system_fingerprint.to_string(),
|
||||
Some(token_text.to_string()),
|
||||
None,
|
||||
current_time,
|
||||
None,
|
||||
finish_reason,
|
||||
None,
|
||||
None,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
(content, None)
|
||||
};
|
||||
|
||||
let (usage, finish_reason) = match &stream_token.details {
|
||||
Some(details) => {
|
||||
let usage = if stream_options
|
||||
.as_ref()
|
||||
.map(|s| s.include_usage)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let completion_tokens = details.generated_tokens;
|
||||
let prompt_tokens = details.input_length;
|
||||
let total_tokens = prompt_tokens + completion_tokens;
|
||||
Some(Usage {
|
||||
completion_tokens,
|
||||
prompt_tokens,
|
||||
total_tokens,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
(usage, Some(details.finish_reason.format(true)))
|
||||
// Chat completion event
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
||||
model_id.to_string(),
|
||||
system_fingerprint.to_string(),
|
||||
Some(token_text.to_string()),
|
||||
None,
|
||||
current_time,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
None => (None, None),
|
||||
};
|
||||
|
||||
let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
||||
model_id.clone(),
|
||||
system_fingerprint.clone(),
|
||||
content,
|
||||
tool_calls,
|
||||
current_time,
|
||||
logprobs,
|
||||
finish_reason,
|
||||
usage,
|
||||
tool_name,
|
||||
));
|
||||
|
||||
event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
|
||||
Event::default()
|
||||
})
|
||||
Event::default()
|
||||
.json_data(chat_complete)
|
||||
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
|
||||
}
|
||||
|
||||
/// Generate tokens
|
||||
@ -1239,13 +1338,12 @@ pub(crate) async fn chat_completions(
|
||||
let ChatRequest {
|
||||
model,
|
||||
stream,
|
||||
stream_options,
|
||||
logprobs,
|
||||
// TODO: add back and maybe consolidate the other PR
|
||||
// stream_options,
|
||||
..
|
||||
} = chat.clone();
|
||||
let (generate_request, using_tools): (GenerateRequest, bool) =
|
||||
chat.try_into_generate(&infer)?;
|
||||
|
||||
let (generate_request, using_tools) = chat.try_into_generate(&infer)?;
|
||||
let logprobs = logprobs.unwrap_or_default();
|
||||
|
||||
// extract model id from request if specified
|
||||
@ -1254,210 +1352,112 @@ pub(crate) async fn chat_completions(
|
||||
Some(m_id) => m_id.to_string(),
|
||||
};
|
||||
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
||||
// switch on stream
|
||||
|
||||
if stream {
|
||||
let (headers, response_stream) =
|
||||
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
|
||||
|
||||
// regex to match any function name
|
||||
let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
|
||||
Ok(regex) => regex,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: format!("Failed to compile regex: {}", e),
|
||||
error_type: "regex".to_string(),
|
||||
}),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let response_stream = async_stream::stream! {
|
||||
let mut response_stream = Box::pin(response_stream);
|
||||
let mut buffer = Vec::new();
|
||||
let mut json_buffer = String::new();
|
||||
let mut state = if using_tools {
|
||||
StreamState::Buffering
|
||||
} else {
|
||||
StreamState::Content {
|
||||
skip_close_quote: false,
|
||||
}
|
||||
};
|
||||
let mut response_as_tool = using_tools;
|
||||
let mut name_found = !using_tools;
|
||||
let mut no_tool_chosen = false;
|
||||
let mut first_quote_removed = false;
|
||||
|
||||
while let Some(result) = response_stream.next().await {
|
||||
match result{
|
||||
Ok(stream_token) => {
|
||||
let token_text = &stream_token.token.text.clone();
|
||||
match state {
|
||||
StreamState::Buffering => {
|
||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||
buffer.push(stream_token);
|
||||
if let Some(captures) = function_regex.captures(&json_buffer) {
|
||||
let function_name = captures[1].to_string();
|
||||
if function_name == "no_tool" {
|
||||
state = StreamState::BufferTrailing;
|
||||
response_as_tool = false;
|
||||
buffer.clear();
|
||||
match result {
|
||||
Ok(stream_token) => {
|
||||
let token_text = stream_token.token.text.clone();
|
||||
json_buffer.push_str(&token_text);
|
||||
if !name_found {
|
||||
// since we know tools is attempting to follow a grammar we can attempt to
|
||||
// partially parse the json_buffer to see if we can extract the function name
|
||||
if let Ok(function) = parse_partial_json(&json_buffer) {
|
||||
let name = function.get("_name").and_then(|n| n.as_str()).unwrap_or("no_tool");
|
||||
name_found = true;
|
||||
if name == "no_tool" {
|
||||
no_tool_chosen = true;
|
||||
json_buffer.clear();
|
||||
json_buffer.push('{');
|
||||
} else {
|
||||
state = StreamState::Content {
|
||||
skip_close_quote: false,
|
||||
};
|
||||
let event = Event::default();
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
let tool_delta_start = ChatCompletionDelta::Tool(ToolCallDelta {
|
||||
role: "assistant".to_string(),
|
||||
tool_calls: vec![DeltaToolCall {
|
||||
index: 0,
|
||||
id: String::new(),
|
||||
r#type: "function".to_string(),
|
||||
function: Function {
|
||||
name: Some(function_name.clone()),
|
||||
arguments: "".to_string(),
|
||||
},
|
||||
}],
|
||||
});
|
||||
let chat_complete =
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk{
|
||||
id: String::new(),
|
||||
created: current_time,
|
||||
model: model_id.clone(),
|
||||
system_fingerprint: system_fingerprint.clone(),
|
||||
choices: vec![ChatCompletionChoice {
|
||||
index: 0,
|
||||
delta: tool_delta_start,
|
||||
logprobs: None,
|
||||
finish_reason: None,
|
||||
}],
|
||||
usage: None,
|
||||
});
|
||||
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||
InferError::StreamSerializationError(e.to_string()).into()
|
||||
}));
|
||||
buffer.drain(1..); // only keep the first token (opening '{')
|
||||
buffer[0].token.text = buffer[0].token.text.chars().take(1).collect();
|
||||
let tool_name_event = create_event(&token_text, &model_id, &system_fingerprint, Some(name), false, None);
|
||||
yield Ok::<Event, Infallible>(tool_name_event);
|
||||
let tool_open_arguments_event = create_event("{", &model_id, &system_fingerprint, None, true, None);
|
||||
yield Ok::<Event, Infallible>(tool_open_arguments_event);
|
||||
// clear the buffer as we know that the buffer is only the function
|
||||
// ie: ` {"function": {"_name": "get_current_weather",` -> `{"`
|
||||
// we need to keep the `{` to open the arguments and allow the parser to continue
|
||||
json_buffer.clear();
|
||||
json_buffer.push('{');
|
||||
}
|
||||
}
|
||||
}
|
||||
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
|
||||
StreamState::BufferTrailing => {
|
||||
let infix_text = "\"content\":\"";
|
||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||
// keep capturing until we find the infix text
|
||||
match json_buffer.find(infix_text) {
|
||||
Some(content_key_index) => {
|
||||
json_buffer =
|
||||
json_buffer[content_key_index + infix_text.len()..].to_string();
|
||||
}
|
||||
None => {
|
||||
} else {
|
||||
// Process JSON buffer and handle token text
|
||||
let last_is_brace = json_buffer.ends_with('}');
|
||||
let edited_buffer = if last_is_brace { &json_buffer[..json_buffer.len() - 1] } else { &json_buffer };
|
||||
let mut token_text = stream_token.token.text.clone();
|
||||
let is_json_complete = serde_json::from_str::<Value>(edited_buffer).is_ok();
|
||||
|
||||
// Handle tool usage cases
|
||||
if using_tools {
|
||||
if no_tool_chosen {
|
||||
// Tool without selection ("content" flow)
|
||||
if let Ok(function) = parse_generic_structure(edited_buffer) {
|
||||
if function.get("content").and_then(|c| c.as_str()).is_some() {
|
||||
// Handle quotation marks
|
||||
if !first_quote_removed {
|
||||
first_quote_removed = true;
|
||||
if token_text == "\"" || token_text == " \"" { continue; }
|
||||
token_text = token_text.replace("\"", "");
|
||||
} else if token_text.ends_with('"') {
|
||||
token_text = token_text[..token_text.len() - 1].to_string();
|
||||
}
|
||||
|
||||
if is_json_complete { break; }
|
||||
yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, false, None));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
// Tool with selection
|
||||
if is_json_complete {
|
||||
// Final token with possible brace removal
|
||||
if last_is_brace { token_text = token_text[..token_text.len() - 1].to_string(); }
|
||||
yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, true, None));
|
||||
break;
|
||||
}
|
||||
yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, true, None));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// if there is leftover text after removing the infix text, we need to send it
|
||||
if !json_buffer.is_empty() {
|
||||
let event = Event::default();
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
let chat_complete =
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
||||
model_id.clone(),
|
||||
system_fingerprint.clone(),
|
||||
Some(json_buffer.clone()),
|
||||
None,
|
||||
current_time,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
));
|
||||
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||
InferError::StreamSerializationError(e.to_string()).into()
|
||||
}));
|
||||
}
|
||||
// cleanup the buffers
|
||||
buffer.clear();
|
||||
json_buffer.clear();
|
||||
state = StreamState::Content {
|
||||
skip_close_quote: true,
|
||||
};
|
||||
}
|
||||
StreamState::Content { skip_close_quote } => {
|
||||
if skip_close_quote && token_text.contains('"') {
|
||||
break;
|
||||
}
|
||||
|
||||
buffer.push(stream_token);
|
||||
if buffer.len() > 1 {
|
||||
// FIFO send the buffer but left the last two elements (closing '}' and EOS token)
|
||||
for stream_token in &buffer[..buffer.len() - 2] {
|
||||
let event = create_event_from_stream_token(
|
||||
stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
None,
|
||||
);
|
||||
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
} else {
|
||||
// Default case: standard chat completion
|
||||
if let Some(details) = stream_token.details.as_ref() {
|
||||
// Handle final token and only send text if ended because of length
|
||||
let text = if details.finish_reason == FinishReason::Length { &token_text } else { "" };
|
||||
yield Ok::<Event, Infallible>(create_event(text, &model_id, &system_fingerprint, None, false, Some(details.finish_reason.format(true))));
|
||||
break;
|
||||
}
|
||||
buffer = buffer.drain(buffer.len() - 2..).collect();
|
||||
yield Ok::<Event, Infallible>(create_event(&token_text, &model_id, &system_fingerprint, None, false, None));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => yield Ok(err.into_openai_event())
|
||||
Err(err) => yield Ok(err.into_openai_event()),
|
||||
}
|
||||
}
|
||||
if response_as_tool {
|
||||
// send the second to last stream token but remove the trailing '}' if it exists
|
||||
let mut closing_stream_token = buffer.remove(0);
|
||||
closing_stream_token.token.text = closing_stream_token.token.text.strip_suffix("}").unwrap_or(&closing_stream_token.token.text).to_string();
|
||||
let event = create_event_from_stream_token(
|
||||
&closing_stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
None,
|
||||
);
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
} else {
|
||||
// send each buffer element
|
||||
for stream_token in buffer {
|
||||
let event = create_event_from_stream_token(
|
||||
&stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
None,
|
||||
);
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
}
|
||||
}
|
||||
|
||||
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
|
||||
};
|
||||
|
||||
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||
Ok((headers, sse).into_response())
|
||||
} else {
|
||||
// Non-streaming case
|
||||
let (headers, input_length, Json(generation)) =
|
||||
generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
|
||||
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let (tool_calls, output) = if using_tools {
|
||||
@ -1490,11 +1490,9 @@ pub(crate) async fn chat_completions(
|
||||
let content_message = arguments
|
||||
.get("content")
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| {
|
||||
InferError::ToolError(
|
||||
"No `content` found in generated text".to_string(),
|
||||
)
|
||||
})?
|
||||
.ok_or(InferError::ToolError(
|
||||
"No `content` found in generated text".to_string(),
|
||||
))?
|
||||
.to_string();
|
||||
(None, Some(content_message))
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user