Reworked the implementation.

This commit is contained in:
Nicolas Patry 2024-11-15 20:24:47 +07:00
parent 22d205aa47
commit df72deac26
No known key found for this signature in database
GPG Key ID: D2920555C90F704C
2 changed files with 114 additions and 126 deletions

View File

@ -14,6 +14,7 @@ use chat_template::ChatTemplate;
use futures::future::try_join_all; use futures::future::try_join_all;
use futures::Stream; use futures::Stream;
use minijinja::ErrorKind; use minijinja::ErrorKind;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
use thiserror::Error; use thiserror::Error;
@ -373,4 +374,25 @@ impl InferError {
InferError::StreamSerializationError(_) => "stream_serialization_error", InferError::StreamSerializationError(_) => "stream_serialization_error",
} }
} }
pub(crate) fn into_openai_event(self) -> Event {
let message = self.to_string();
Event::default().json_data(OpenaiErrorEvent {
error: APIError {
message,
http_status_code: 422,
},
})
}
}
#[derive(Serialize)]
pub struct APIError {
message: String,
http_status_code: usize,
}
#[derive(Serialize)]
pub struct OpenaiErrorEvent {
error: APIError,
} }

View File

@ -866,14 +866,7 @@ pub(crate) async fn completions(
yield Ok(event); yield Ok(event);
} }
Err(err) => { Err(err) => yield Ok(err.into_openai_event()),
let event = Event::default()
.json_data(ErrorEvent::into_api_error(err, 422))
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into());
println!("{:?}", event);
yield Ok::<Event, Infallible>(event);
break
}
} }
} }
}; };
@ -1281,107 +1274,102 @@ pub(crate) async fn chat_completions(
}; };
let mut response_as_tool = using_tools; let mut response_as_tool = using_tools;
while let Some(result) = response_stream.next().await { while let Some(result) = response_stream.next().await {
match result { match result{
Ok(stream_token) => { Ok(stream_tokens) => {
let token_text = &stream_token.token.text.clone(); let token_text = &stream_token.token.text.clone();
match state { match state {
StreamState::Buffering => { StreamState::Buffering => {
json_buffer.push_str(&token_text.replace(" ", "")); json_buffer.push_str(&token_text.replace(" ", ""));
buffer.push(stream_token); buffer.push(stream_token);
if let Some(captures) = function_regex.captures(&json_buffer) { if let Some(captures) = function_regex.captures(&json_buffer) {
let function_name = captures[1].to_string(); let function_name = captures[1].to_string();
if function_name == "no_tool" { if function_name == "no_tool" {
state = StreamState::BufferTrailing; state = StreamState::BufferTrailing;
response_as_tool = false; response_as_tool = false;
buffer.clear(); buffer.clear();
json_buffer.clear(); json_buffer.clear();
} else { } else {
state = StreamState::Content { state = StreamState::Content {
skip_close_quote: false, skip_close_quote: false,
}; };
// send all the buffered messages // send all the buffered messages
for stream_token in &buffer { for stream_token in &buffer {
let event = create_event_from_stream_token( let event = create_event_from_stream_token(
stream_token, stream_token,
logprobs, logprobs,
stream_options.clone(), stream_options.clone(),
response_as_tool, response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
}
}
}
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
StreamState::BufferTrailing => {
let infix_text = "\"content\":\"";
json_buffer.push_str(&token_text.replace(" ", ""));
// keep capturing until we find the infix text
match json_buffer.find(infix_text) {
Some(content_key_index) => {
json_buffer =
json_buffer[content_key_index + infix_text.len()..].to_string();
}
None => {
continue;
}
}
// if there is leftover text after removing the infix text, we need to send it
if !json_buffer.is_empty() {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let chat_complete =
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(), system_fingerprint.clone(),
Some(json_buffer.clone()), model_id.clone(),
None, );
current_time, yield Ok::<Event, Infallible>(event);
None, }
None,
None,
));
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
}));
} }
// cleanup the buffers
buffer.clear();
json_buffer.clear();
state = StreamState::Content {
skip_close_quote: true,
};
}
StreamState::Content { skip_close_quote } => {
if skip_close_quote && token_text.contains('"') {
break;
}
// send the content
let event = create_event_from_stream_token(
&stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
} }
} }
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
StreamState::BufferTrailing => {
let infix_text = "\"content\":\"";
json_buffer.push_str(&token_text.replace(" ", ""));
// keep capturing until we find the infix text
match json_buffer.find(infix_text) {
Some(content_key_index) => {
json_buffer =
json_buffer[content_key_index + infix_text.len()..].to_string();
}
None => {
continue;
}
}
// if there is leftover text after removing the infix text, we need to send it
if !json_buffer.is_empty() {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let chat_complete =
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
Some(json_buffer.clone()),
None,
current_time,
None,
None,
None,
));
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
}));
}
// cleanup the buffers
buffer.clear();
json_buffer.clear();
state = StreamState::Content {
skip_close_quote: true,
};
}
StreamState::Content { skip_close_quote } => {
if skip_close_quote && token_text.contains('"') {
break;
}
// send the content
let event = create_event_from_stream_token(
&stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
} }
Err(err) => { },
let event = Event::default() Err(err) => yield Event::from_openai(err)
.json_data(ErrorEvent::into_api_error(err, 422))
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into());
yield Ok::<Event, Infallible>(event);
break;
}
} }
} }
yield Ok::<Event, Infallible>(Event::default().data("[DONE]")); yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
@ -2532,28 +2520,6 @@ impl From<InferError> for Event {
} }
} }
#[derive(serde::Serialize)]
pub struct APIError {
message: String,
http_status_code: usize,
}
#[derive(serde::Serialize)]
pub struct ErrorEvent {
error: APIError,
}
impl ErrorEvent {
fn into_api_error(err: InferError, http_status_code: usize) -> Self {
ErrorEvent {
error: APIError {
message: err.to_string(),
http_status_code,
},
}
}
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum WebServerError { pub enum WebServerError {
#[error("Axum error: {0}")] #[error("Axum error: {0}")]