Reworked the implementation.

This commit is contained in:
Nicolas Patry 2024-11-15 20:24:47 +07:00
parent 22d205aa47
commit df72deac26
No known key found for this signature in database
GPG Key ID: D2920555C90F704C
2 changed files with 114 additions and 126 deletions

View File

@ -14,6 +14,7 @@ use chat_template::ChatTemplate;
use futures::future::try_join_all;
use futures::Stream;
use minijinja::ErrorKind;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use thiserror::Error;
@ -373,4 +374,25 @@ impl InferError {
InferError::StreamSerializationError(_) => "stream_serialization_error",
}
}
pub(crate) fn into_openai_event(self) -> Event {
let message = self.to_string();
Event::default().json_data(OpenaiErrorEvent {
error: APIError {
message,
http_status_code: 422,
},
})
}
}
/// Error details nested under the `error` key of an [`OpenaiErrorEvent`].
///
/// Field names are emitted as-is by the derived `Serialize` impl, so they are
/// part of the JSON payload sent to clients.
#[derive(Serialize)]
pub struct APIError {
/// Human-readable description; `into_openai_event` fills it with the
/// error's `to_string()` output.
message: String,
/// HTTP status code reported inside the payload; `into_openai_event`
/// sets it to 422.
http_status_code: usize,
}
/// JSON payload of an error event: a single top-level `error` object.
///
/// Built by `into_openai_event` and passed to `Event::json_data`.
#[derive(Serialize)]
pub struct OpenaiErrorEvent {
/// The error details (message and HTTP status code).
error: APIError,
}

View File

@ -866,14 +866,7 @@ pub(crate) async fn completions(
yield Ok(event);
}
Err(err) => {
let event = Event::default()
.json_data(ErrorEvent::into_api_error(err, 422))
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into());
println!("{:?}", event);
yield Ok::<Event, Infallible>(event);
break
}
Err(err) => yield Ok(err.into_openai_event()),
}
}
};
@ -1282,7 +1275,7 @@ pub(crate) async fn chat_completions(
let mut response_as_tool = using_tools;
while let Some(result) = response_stream.next().await {
match result{
Ok(stream_token) => {
Ok(stream_tokens) => {
let token_text = &stream_token.token.text.clone();
match state {
StreamState::Buffering => {
@ -1361,6 +1354,7 @@ pub(crate) async fn chat_completions(
if skip_close_quote && token_text.contains('"') {
break;
}
// send the content
let event = create_event_from_stream_token(
&stream_token,
@ -1374,14 +1368,8 @@ pub(crate) async fn chat_completions(
yield Ok::<Event, Infallible>(event);
}
}
}
Err(err) => {
let event = Event::default()
.json_data(ErrorEvent::into_api_error(err, 422))
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into());
yield Ok::<Event, Infallible>(event);
break;
}
},
Err(err) => yield Event::from_openai(err)
}
}
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
@ -2532,28 +2520,6 @@ impl From<InferError> for Event {
}
}
/// Error details nested under the `error` key of an [`ErrorEvent`].
///
/// Field names are emitted as-is by the derived `Serialize` impl.
#[derive(serde::Serialize)]
pub struct APIError {
/// Human-readable description; `ErrorEvent::into_api_error` fills it with
/// the error's `to_string()` output.
message: String,
/// HTTP status code reported inside the payload, as supplied by the caller
/// of `ErrorEvent::into_api_error`.
http_status_code: usize,
}
/// JSON payload of an error event: a single top-level `error` object.
#[derive(serde::Serialize)]
pub struct ErrorEvent {
/// The error details (message and HTTP status code).
error: APIError,
}
impl ErrorEvent {
fn into_api_error(err: InferError, http_status_code: usize) -> Self {
ErrorEvent {
error: APIError {
message: err.to_string(),
http_status_code,
},
}
}
}
#[derive(Debug, Error)]
pub enum WebServerError {
#[error("Axum error: {0}")]