diff --git a/router/src/infer.rs b/router/src/infer.rs
index 3ce4923c..eb489b5f 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -140,24 +140,20 @@ impl Infer {
     /// Apply the chat template to the chat request
     #[instrument(skip_all)]
-    pub(crate) fn apply_chat_template(
-        &self,
-        chat: ChatRequest,
-    ) -> Result<String, ChatTemplateError> {
+
+    pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
         let mut env = minijinja::Environment::new();
         let chat_template = self
             .tokenizer_config
             .chat_template
             .as_ref()
-            .ok_or(ChatTemplateError::TemplateNotFound)?;
-        env.add_template("_", chat_template)
-            .map_err(|e| ChatTemplateError::TemplateError(e))?;
-        let jinja_tmpl = env
-            .get_template("_")
-            .map_err(|e| ChatTemplateError::TemplateError(e))?;
-        jinja_tmpl
+            .ok_or_else(|| {
+                InferError::TemplateError(minijinja::ErrorKind::TemplateNotFound.into())
+            })?;
+        env.add_template("_", chat_template)?;
+        env.get_template("_")?
             .render(chat)
-            .map_err(|e| ChatTemplateError::TemplateError(e))
+            .map_err(InferError::TemplateError)
     }
 
     /// Add a new request to the queue and return a InferResponse
@@ -570,9 +566,9 @@ fn send_responses(
     let mut iterator = tokens_
         .ids
         .into_iter()
-        .zip(tokens_.logprobs.into_iter())
-        .zip(tokens_.texts.into_iter())
-        .zip(tokens_.is_special.into_iter())
+        .zip(tokens_.logprobs)
+        .zip(tokens_.texts)
+        .zip(tokens_.is_special)
         .enumerate()
         .peekable();
     while let Some((i, (((id, logprob), text), special))) = iterator.next() {
@@ -681,6 +677,8 @@ pub enum InferError {
     ValidationError(#[from] ValidationError),
     #[error("Incomplete generation")]
     IncompleteGeneration,
+    #[error("Template error: {0}")]
+    TemplateError(#[from] minijinja::Error),
 }
 
 impl InferError {
@@ -690,23 +688,7 @@ impl InferError {
             InferError::Overloaded(_) => "overloaded",
             InferError::ValidationError(_) => "validation",
             InferError::IncompleteGeneration => "incomplete_generation",
-        }
-    }
-}
-
-#[derive(Debug, Error)]
-pub enum ChatTemplateError {
-    #[error("Template error: {0}")]
-    TemplateError(#[from] minijinja::Error),
-    #[error("Template not found")]
-    TemplateNotFound,
-}
-
-impl ChatTemplateError {
-    pub(crate) fn error_type(&self) -> &str {
-        match self {
-            ChatTemplateError::TemplateError(_) => "template_error",
-            ChatTemplateError::TemplateNotFound => "template_not_found",
+            InferError::TemplateError(_) => "template_error",
         }
     }
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index d4521d96..323bf35f 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -21,6 +21,7 @@ use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
+use std::sync::atomic::{AtomicU32, Ordering};
 use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
@@ -343,7 +344,7 @@ async fn generate_stream(
         event.json_data(stream_token).unwrap()
     };
     let (headers, response_stream) =
-        generate_stream_internal(infer, Json(req.into()), on_message_callback).await;
+        generate_stream_internal(infer, Json(req), on_message_callback).await;
     let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
     (headers, sse)
 }
@@ -547,7 +548,7 @@ async fn generate_stream_internal(
         seed,
     )
 )]
-async fn chat(
+async fn chat_completions(
     Extension(infer): Extension<Infer>,
     Json(req): Json<ChatRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
@@ -557,7 +558,7 @@ async fn chat(
     let stream = req.stream;
     let max_new_tokens = match req.max_tokens {
         Some(max_new_tokens) => Some(max_new_tokens),
-        None => Some(100)
+        None => Some(100),
     };
 
     // apply chat template to flatten the request into a single input
@@ -604,8 +605,9 @@ async fn chat(
     // switch on stream
     if stream {
+        let stream_count = AtomicU32::new(0);
         // pass this callback to the stream generation and build the required event structure
-        let on_message_callback = |stream_token: StreamResponse| {
+        let on_message_callback = move |stream_token: StreamResponse| {
            let event = Event::default();

            let current_time = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
@@ -613,11 +615,15 @@ async fn chat(
                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                .as_secs();

+            // increment the stream count
+            stream_count.fetch_add(1, Ordering::SeqCst);
+            let current_stream_count = stream_count.load(Ordering::SeqCst);
+
            event
                .json_data(ChatCompletionChunk::new(
                    stream_token.token.text,
                    current_time,
-                    0,
+                    current_stream_count,
                ))
                .unwrap_or_else(|_| {
                    println!("Failed to serialize ChatCompletionChunk");
@@ -843,7 +849,7 @@ pub async fn run(
         .route("/info", get(get_model_info))
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
-        .route("/v1/chat/completions", post(chat))
+        .route("/v1/chat/completions", post(chat_completions))
         // AWS Sagemaker route
         .route("/invocations", post(compat_generate))
         // Base Health route
@@ -973,6 +979,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
             InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
             InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
+            InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
         };
         (
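For context, a minimal standalone sketch of the counter pattern the streaming callback now uses: an `AtomicU32` owned by a `move` closure and bumped once per emitted chunk. This is not part of the diff; the names `on_message` and `token_text` are illustrative only.

```rust
use std::sync::atomic::{AtomicU32, Ordering};

fn main() {
    // The counter is moved into the closure, but fetch_add only needs
    // `&self`, so the closure remains `Fn` and can be called repeatedly.
    let stream_count = AtomicU32::new(0);
    let on_message = move |token_text: &str| {
        // Same two-step pattern as the handler: bump, then read back.
        stream_count.fetch_add(1, Ordering::SeqCst);
        let current = stream_count.load(Ordering::SeqCst);
        println!("chunk index {current}: {token_text}");
    };

    for token in ["Hello", ",", " world", "!"] {
        on_message(token); // prints indices 1..=4
    }
}
```

Since `fetch_add` returns the previous value, the separate `load` could be folded into `let current = stream_count.fetch_add(1, Ordering::SeqCst) + 1;`, which would also stay correct if the callback were ever invoked from more than one thread.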