diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 557e03cb..ea1fc773 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -14,6 +14,7 @@ use chat_template::ChatTemplate;
 use futures::future::try_join_all;
 use futures::Stream;
 use minijinja::ErrorKind;
+use serde::{Deserialize, Serialize};
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 use thiserror::Error;
@@ -373,4 +374,26 @@ impl InferError {
             InferError::StreamSerializationError(_) => "stream_serialization_error",
         }
     }
+
+    pub(crate) fn into_openai_event(self) -> Event {
+        let message = self.to_string();
+        Event::default().json_data(OpenaiErrorEvent {
+            error: APIError {
+                message,
+                http_status_code: 422,
+            },
+        })
+        .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
+    }
+}
+
+#[derive(Serialize)]
+pub struct APIError {
+    message: String,
+    http_status_code: usize,
+}
+
+#[derive(Serialize)]
+pub struct OpenaiErrorEvent {
+    error: APIError,
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index 41b59665..f0469ca5 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -866,14 +866,7 @@ pub(crate) async fn completions(
                         yield Ok(event);
                     }
-                    Err(err) => {
-                        let event = Event::default()
-                            .json_data(ErrorEvent::into_api_error(err, 422))
-                            .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into());
-                        println!("{:?}", event);
-                        yield Ok::(event);
-                        break
-                    }
+                    Err(err) => yield Ok(err.into_openai_event()),
                 }
             }
         };
@@ -1281,107 +1274,102 @@ pub(crate) async fn chat_completions(
         };
         let mut response_as_tool = using_tools;
         while let Some(result) = response_stream.next().await {
-            match result {
-                Ok(stream_token) => {
-                    let token_text = &stream_token.token.text.clone();
-                    match state {
-                        StreamState::Buffering => {
-                            json_buffer.push_str(&token_text.replace(" ", ""));
-                            buffer.push(stream_token);
-                            if let Some(captures) = function_regex.captures(&json_buffer) {
-                                let function_name = captures[1].to_string();
-                                if function_name == "no_tool" {
-                                    state = StreamState::BufferTrailing;
-                                    
response_as_tool = false;
-                                    buffer.clear();
-                                    json_buffer.clear();
-                                } else {
-                                    state = StreamState::Content {
-                                        skip_close_quote: false,
-                                    };
-                                    // send all the buffered messages
-                                    for stream_token in &buffer {
-                                        let event = create_event_from_stream_token(
-                                            stream_token,
-                                            logprobs,
-                                            stream_options.clone(),
-                                            response_as_tool,
-                                            system_fingerprint.clone(),
-                                            model_id.clone(),
-                                        );
-                                        yield Ok::(event);
-                                    }
-                                }
-                            }
-                        }
-                        // if we skipped sending the buffer we need to avoid sending the following json key and quotes
-                        StreamState::BufferTrailing => {
-                            let infix_text = "\"content\":\"";
-                            json_buffer.push_str(&token_text.replace(" ", ""));
-                            // keep capturing until we find the infix text
-                            match json_buffer.find(infix_text) {
-                                Some(content_key_index) => {
-                                    json_buffer =
-                                        json_buffer[content_key_index + infix_text.len()..].to_string();
-                                }
-                                None => {
-                                    continue;
-                                }
-                            }
-                            // if there is leftover text after removing the infix text, we need to send it
-                            if !json_buffer.is_empty() {
-                                let event = Event::default();
-                                let current_time = std::time::SystemTime::now()
-                                    .duration_since(std::time::UNIX_EPOCH)
-                                    .unwrap_or_else(|_| std::time::Duration::from_secs(0))
-                                    .as_secs();
-                                let chat_complete =
-                                    CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
-                                        model_id.clone(),
+            match result {
+                Ok(stream_token) => {
+                    let token_text = &stream_token.token.text.clone();
+                    match state {
+                        StreamState::Buffering => {
+                            json_buffer.push_str(&token_text.replace(" ", ""));
+                            buffer.push(stream_token);
+                            if let Some(captures) = function_regex.captures(&json_buffer) {
+                                let function_name = captures[1].to_string();
+                                if function_name == "no_tool" {
+                                    state = StreamState::BufferTrailing;
+                                    response_as_tool = false;
+                                    buffer.clear();
+                                    json_buffer.clear();
+                                } else {
+                                    state = StreamState::Content {
+                                        skip_close_quote: false,
+                                    };
+                                    // send all the buffered messages
+                                    for stream_token in &buffer {
+                                        let event = create_event_from_stream_token(
+                                            stream_token,
+                                            logprobs,
+                                            
stream_options.clone(), + response_as_tool, system_fingerprint.clone(), - Some(json_buffer.clone()), - None, - current_time, - None, - None, - None, - )); - yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| { - InferError::StreamSerializationError(e.to_string()).into() - })); + model_id.clone(), + ); + yield Ok::(event); + } } - // cleanup the buffers - buffer.clear(); - json_buffer.clear(); - state = StreamState::Content { - skip_close_quote: true, - }; - } - StreamState::Content { skip_close_quote } => { - if skip_close_quote && token_text.contains('"') { - break; - } - // send the content - let event = create_event_from_stream_token( - &stream_token, - logprobs, - stream_options.clone(), - response_as_tool, - system_fingerprint.clone(), - model_id.clone(), - ); - - yield Ok::(event); } } + // if we skipped sending the buffer we need to avoid sending the following json key and quotes + StreamState::BufferTrailing => { + let infix_text = "\"content\":\""; + json_buffer.push_str(&token_text.replace(" ", "")); + // keep capturing until we find the infix text + match json_buffer.find(infix_text) { + Some(content_key_index) => { + json_buffer = + json_buffer[content_key_index + infix_text.len()..].to_string(); + } + None => { + continue; + } + } + // if there is leftover text after removing the infix text, we need to send it + if !json_buffer.is_empty() { + let event = Event::default(); + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| std::time::Duration::from_secs(0)) + .as_secs(); + let chat_complete = + CompletionType::ChatCompletionChunk(ChatCompletionChunk::new( + model_id.clone(), + system_fingerprint.clone(), + Some(json_buffer.clone()), + None, + current_time, + None, + None, + None, + )); + yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| { + InferError::StreamSerializationError(e.to_string()).into() + })); + } + // cleanup the buffers + buffer.clear(); + json_buffer.clear(); + 
state = StreamState::Content {
+                                skip_close_quote: true,
+                            };
+                        }
+                        StreamState::Content { skip_close_quote } => {
+                            if skip_close_quote && token_text.contains('"') {
+                                break;
+                            }
+
+                            // send the content
+                            let event = create_event_from_stream_token(
+                                &stream_token,
+                                logprobs,
+                                stream_options.clone(),
+                                response_as_tool,
+                                system_fingerprint.clone(),
+                                model_id.clone(),
+                            );
+
+                            yield Ok::(event);
+                        }
                     }
-            Err(err) => {
-                let event = Event::default()
-                    .json_data(ErrorEvent::into_api_error(err, 422))
-                    .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into());
-                yield Ok::(event);
-                break;
-            }
+                },
+                Err(err) => yield Ok(err.into_openai_event()),
             }
         }
         yield Ok::(Event::default().data("[DONE]"));
@@ -2532,28 +2520,6 @@ impl From for Event {
     }
 }
 
-#[derive(serde::Serialize)]
-pub struct APIError {
-    message: String,
-    http_status_code: usize,
-}
-
-#[derive(serde::Serialize)]
-pub struct ErrorEvent {
-    error: APIError,
-}
-
-impl ErrorEvent {
-    fn into_api_error(err: InferError, http_status_code: usize) -> Self {
-        ErrorEvent {
-            error: APIError {
-                message: err.to_string(),
-                http_status_code,
-            },
-        }
-    }
-}
-
 #[derive(Debug, Error)]
 pub enum WebServerError {
     #[error("Axum error: {0}")]