From 84cd8434b076734b091172ab32c38e106e1cc388 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 18 Oct 2024 14:15:27 -0400 Subject: [PATCH] feat: return streaming errors as an event formatted for openai's client --- router/src/server.rs | 197 +++++++++++++++++++++++++------------------ 1 file changed, 113 insertions(+), 84 deletions(-) diff --git a/router/src/server.rs b/router/src/server.rs index a0bc1768..bfdd6a8b 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1274,99 +1274,108 @@ pub(crate) async fn chat_completions( }; let mut response_as_tool = using_tools; while let Some(result) = response_stream.next().await { - if let Ok(stream_token) = result { - let token_text = &stream_token.token.text.clone(); - match state { - StreamState::Buffering => { - json_buffer.push_str(&token_text.replace(" ", "")); - buffer.push(stream_token); - if let Some(captures) = function_regex.captures(&json_buffer) { - let function_name = captures[1].to_string(); - if function_name == "no_tool" { - state = StreamState::BufferTrailing; - response_as_tool = false; - buffer.clear(); - json_buffer.clear(); - } else { - state = StreamState::Content { - skip_close_quote: false, - }; - // send all the buffered messages - for stream_token in &buffer { - let event = create_event_from_stream_token( - stream_token, - logprobs, - stream_options.clone(), - response_as_tool, - system_fingerprint.clone(), - model_id.clone(), - ); - yield Ok::(event); + match result { + Ok(stream_token) => { + let token_text = &stream_token.token.text.clone(); + match state { + StreamState::Buffering => { + json_buffer.push_str(&token_text.replace(" ", "")); + buffer.push(stream_token); + if let Some(captures) = function_regex.captures(&json_buffer) { + let function_name = captures[1].to_string(); + if function_name == "no_tool" { + state = StreamState::BufferTrailing; + response_as_tool = false; + buffer.clear(); + json_buffer.clear(); + } else { + state = StreamState::Content { + skip_close_quote: false, + }; + // send all the buffered messages + for stream_token in &buffer { + let event = create_event_from_stream_token( + stream_token, + logprobs, + stream_options.clone(), + response_as_tool, + system_fingerprint.clone(), + model_id.clone(), + ); + yield Ok::(event); + } } } } - } - // if we skipped sending the buffer we need to avoid sending the following json key and quotes - StreamState::BufferTrailing => { - let infix_text = "\"content\":\""; - json_buffer.push_str(&token_text.replace(" ", "")); - // keep capturing until we find the infix text - match json_buffer.find(infix_text) { - Some(content_key_index) => { - json_buffer = - json_buffer[content_key_index + infix_text.len()..].to_string(); + // if we skipped sending the buffer we need to avoid sending the following json key and quotes + StreamState::BufferTrailing => { + let infix_text = "\"content\":\""; + json_buffer.push_str(&token_text.replace(" ", "")); + // keep capturing until we find the infix text + match json_buffer.find(infix_text) { + Some(content_key_index) => { + json_buffer = + json_buffer[content_key_index + infix_text.len()..].to_string(); + } + None => { + continue; + } } - None => { - continue; + // if there is leftover text after removing the infix text, we need to send it + if !json_buffer.is_empty() { + let event = Event::default(); + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| std::time::Duration::from_secs(0)) + .as_secs(); + let chat_complete = + CompletionType::ChatCompletionChunk(ChatCompletionChunk::new( + model_id.clone(), + system_fingerprint.clone(), + Some(json_buffer.clone()), + None, + current_time, + None, + None, + None, + )); + yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| { + InferError::StreamSerializationError(e.to_string()).into() + })); } + // cleanup the buffers + buffer.clear(); + json_buffer.clear(); + state = StreamState::Content { + skip_close_quote: true, + }; } - // if there is leftover text after removing the infix text, we need to send it - if !json_buffer.is_empty() { - let event = Event::default(); - let current_time = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_else(|_| std::time::Duration::from_secs(0)) - .as_secs(); - let chat_complete = - CompletionType::ChatCompletionChunk(ChatCompletionChunk::new( - model_id.clone(), - system_fingerprint.clone(), - Some(json_buffer.clone()), - None, - current_time, - None, - None, - None, - )); - yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| { - InferError::StreamSerializationError(e.to_string()).into() - })); - } - // cleanup the buffers - buffer.clear(); - json_buffer.clear(); - state = StreamState::Content { - skip_close_quote: true, - }; - } - StreamState::Content { skip_close_quote } => { - if skip_close_quote && token_text.contains('"') { - break; - } + StreamState::Content { skip_close_quote } => { + if skip_close_quote && token_text.contains('"') { + break; + } + // send the content + let event = create_event_from_stream_token( + &stream_token, + logprobs, + stream_options.clone(), + response_as_tool, + system_fingerprint.clone(), + model_id.clone(), + ); - // send the content - let event = create_event_from_stream_token( - &stream_token, - logprobs, - stream_options.clone(), - response_as_tool, - system_fingerprint.clone(), - model_id.clone(), - ); - - yield Ok::(event); + yield Ok::(event); + } } } + Err(err) => { + let error_event: ErrorEvent = err.into(); + let event = Event::default().json_data(error_event).unwrap_or_else(|e| { + InferError::StreamSerializationError(e.to_string()).into() + }); + yield Ok::(event); + break; + } } } yield Ok::(Event::default().data("[DONE]")); @@ -2517,6 +2526,26 @@ impl From for Event { } } +#[derive(serde::Serialize)] +pub struct ErrorWithMessage { + message: String, +} + +#[derive(serde::Serialize)] +pub struct ErrorEvent { + error: ErrorWithMessage, +} + +impl From for ErrorEvent { + fn from(err: InferError) -> Self { + ErrorEvent { + error: ErrorWithMessage { + message: err.to_string(), + }, + } + } +} + #[derive(Debug, Error)] pub enum WebServerError { #[error("Axum error: {0}")]