feat: return streaming errors as an event formatted for OpenAI's client

This commit is contained in:
drbh 2024-10-18 14:15:27 -04:00 committed by Nicolas Patry
parent 34a3bdedc3
commit 84cd8434b0
No known key found for this signature in database
GPG Key ID: D2920555C90F704C

View File

@ -1274,99 +1274,108 @@ pub(crate) async fn chat_completions(
}; };
let mut response_as_tool = using_tools; let mut response_as_tool = using_tools;
while let Some(result) = response_stream.next().await { while let Some(result) = response_stream.next().await {
if let Ok(stream_token) = result { match result {
let token_text = &stream_token.token.text.clone(); Ok(stream_token) => {
match state { let token_text = &stream_token.token.text.clone();
StreamState::Buffering => { match state {
json_buffer.push_str(&token_text.replace(" ", "")); StreamState::Buffering => {
buffer.push(stream_token); json_buffer.push_str(&token_text.replace(" ", ""));
if let Some(captures) = function_regex.captures(&json_buffer) { buffer.push(stream_token);
let function_name = captures[1].to_string(); if let Some(captures) = function_regex.captures(&json_buffer) {
if function_name == "no_tool" { let function_name = captures[1].to_string();
state = StreamState::BufferTrailing; if function_name == "no_tool" {
response_as_tool = false; state = StreamState::BufferTrailing;
buffer.clear(); response_as_tool = false;
json_buffer.clear(); buffer.clear();
} else { json_buffer.clear();
state = StreamState::Content { } else {
skip_close_quote: false, state = StreamState::Content {
}; skip_close_quote: false,
// send all the buffered messages };
for stream_token in &buffer { // send all the buffered messages
let event = create_event_from_stream_token( for stream_token in &buffer {
stream_token, let event = create_event_from_stream_token(
logprobs, stream_token,
stream_options.clone(), logprobs,
response_as_tool, stream_options.clone(),
system_fingerprint.clone(), response_as_tool,
model_id.clone(), system_fingerprint.clone(),
); model_id.clone(),
yield Ok::<Event, Infallible>(event); );
yield Ok::<Event, Infallible>(event);
}
} }
} }
} }
} // if we skipped sending the buffer we need to avoid sending the following json key and quotes
// if we skipped sending the buffer we need to avoid sending the following json key and quotes StreamState::BufferTrailing => {
StreamState::BufferTrailing => { let infix_text = "\"content\":\"";
let infix_text = "\"content\":\""; json_buffer.push_str(&token_text.replace(" ", ""));
json_buffer.push_str(&token_text.replace(" ", "")); // keep capturing until we find the infix text
// keep capturing until we find the infix text match json_buffer.find(infix_text) {
match json_buffer.find(infix_text) { Some(content_key_index) => {
Some(content_key_index) => { json_buffer =
json_buffer = json_buffer[content_key_index + infix_text.len()..].to_string();
json_buffer[content_key_index + infix_text.len()..].to_string(); }
None => {
continue;
}
} }
None => { // if there is leftover text after removing the infix text, we need to send it
continue; if !json_buffer.is_empty() {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let chat_complete =
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
Some(json_buffer.clone()),
None,
current_time,
None,
None,
None,
));
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
}));
} }
// cleanup the buffers
buffer.clear();
json_buffer.clear();
state = StreamState::Content {
skip_close_quote: true,
};
} }
// if there is leftover text after removing the infix text, we need to send it StreamState::Content { skip_close_quote } => {
if !json_buffer.is_empty() { if skip_close_quote && token_text.contains('"') {
let event = Event::default(); break;
let current_time = std::time::SystemTime::now() }
.duration_since(std::time::UNIX_EPOCH) // send the content
.unwrap_or_else(|_| std::time::Duration::from_secs(0)) let event = create_event_from_stream_token(
.as_secs(); &stream_token,
let chat_complete = logprobs,
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new( stream_options.clone(),
model_id.clone(), response_as_tool,
system_fingerprint.clone(), system_fingerprint.clone(),
Some(json_buffer.clone()), model_id.clone(),
None, );
current_time,
None,
None,
None,
));
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
}));
}
// cleanup the buffers
buffer.clear();
json_buffer.clear();
state = StreamState::Content {
skip_close_quote: true,
};
}
StreamState::Content { skip_close_quote } => {
if skip_close_quote && token_text.contains('"') {
break;
}
// send the content yield Ok::<Event, Infallible>(event);
let event = create_event_from_stream_token( }
&stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
} }
} }
Err(err) => {
let error_event: ErrorEvent = err.into();
let event = Event::default().json_data(error_event).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
});
yield Ok::<Event, Infallible>(event);
break;
}
} }
} }
yield Ok::<Event, Infallible>(Event::default().data("[DONE]")); yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
@ -2517,6 +2526,26 @@ impl From<InferError> for Event {
} }
} }
/// Inner error payload: carries only the human-readable error text.
/// Serializes as `{"message": "..."}` — the nested object inside [`ErrorEvent`].
#[derive(serde::Serialize)]
pub struct ErrorWithMessage {
    // Display text of the originating error (see `From<InferError> for ErrorEvent`).
    message: String,
}
/// Error body emitted on the SSE stream when generation fails.
/// Serializes as `{"error": {"message": "..."}}` — presumably the shape
/// OpenAI-compatible streaming clients expect (per the commit intent); confirm against client.
#[derive(serde::Serialize)]
pub struct ErrorEvent {
    // Wrapped message object; nesting produces the `"error"` top-level key.
    error: ErrorWithMessage,
}
impl From<InferError> for ErrorEvent {
fn from(err: InferError) -> Self {
ErrorEvent {
error: ErrorWithMessage {
message: err.to_string(),
},
}
}
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum WebServerError { pub enum WebServerError {
#[error("Axum error: {0}")] #[error("Axum error: {0}")]