feat: return streaming errors as an event formatted for openai's client

This commit is contained in:
drbh 2024-10-18 14:15:27 -04:00 committed by Nicolas Patry
parent 34a3bdedc3
commit 84cd8434b0
No known key found for this signature in database
GPG Key ID: D2920555C90F704C

View File

@ -1274,7 +1274,8 @@ pub(crate) async fn chat_completions(
}; };
let mut response_as_tool = using_tools; let mut response_as_tool = using_tools;
while let Some(result) = response_stream.next().await { while let Some(result) = response_stream.next().await {
if let Ok(stream_token) = result { match result {
Ok(stream_token) => {
let token_text = &stream_token.token.text.clone(); let token_text = &stream_token.token.text.clone();
match state { match state {
StreamState::Buffering => { StreamState::Buffering => {
@ -1353,7 +1354,6 @@ pub(crate) async fn chat_completions(
if skip_close_quote && token_text.contains('"') { if skip_close_quote && token_text.contains('"') {
break; break;
} }
// send the content // send the content
let event = create_event_from_stream_token( let event = create_event_from_stream_token(
&stream_token, &stream_token,
@ -1368,6 +1368,15 @@ pub(crate) async fn chat_completions(
} }
} }
} }
Err(err) => {
let error_event: ErrorEvent = err.into();
let event = Event::default().json_data(error_event).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
});
yield Ok::<Event, Infallible>(event);
break;
}
}
} }
yield Ok::<Event, Infallible>(Event::default().data("[DONE]")); yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
}; };
@ -2517,6 +2526,26 @@ impl From<InferError> for Event {
} }
} }
#[derive(serde::Serialize)]
pub struct ErrorWithMessage {
message: String,
}
#[derive(serde::Serialize)]
pub struct ErrorEvent {
error: ErrorWithMessage,
}
impl From<InferError> for ErrorEvent {
fn from(err: InferError) -> Self {
ErrorEvent {
error: ErrorWithMessage {
message: err.to_string(),
},
}
}
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum WebServerError { pub enum WebServerError {
#[error("Axum error: {0}")] #[error("Axum error: {0}")]