diff --git a/router/Cargo.toml b/router/Cargo.toml index 2a51773d..b9759465 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -22,8 +22,14 @@ serde = "1.0.145" serde_json = "1.0.85" thiserror = "1.0.37" tokenizers = "0.13.0" -tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "net"] } +tokio = { version = "1.21.1", features = [ + "rt", + "rt-multi-thread", + "parking_lot", + "signal", + "sync", + "net", +] } tracing = "0.1.36" tracing-subscriber = { version = "0.3.15", features = ["json"] } async-stream = "0.3.3" - diff --git a/router/src/batcher.rs b/router/src/batcher.rs index a4059d13..7e9d0f01 100644 --- a/router/src/batcher.rs +++ b/router/src/batcher.rs @@ -5,9 +5,9 @@ use axum::http::StatusCode; use axum::Json; use std::future::Future; use std::sync::Arc; -use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient, Intermediate}; +use text_generation_client::{Batch, ClientError, GeneratedText, Intermediate, ShardedClient}; use thiserror::Error; -use tokio::sync::{oneshot, Notify, mpsc}; +use tokio::sync::{mpsc, oneshot, Notify}; use tokio::time::Instant; use tracing::instrument; @@ -71,13 +71,6 @@ impl Batcher { // Notify the background task that we have a new entry in the database that needs // to be batched self.shared.batching_task.notify_one(); - - // // Await on the response from the background task - // // We can safely unwrap as the background task will never drop the sender - // response_rx - // .await - // .unwrap() - // .map_err(|err| InferError::GenerationError(err.to_string())) } /// Add a new request to the database and return a future that will generate the text @@ -184,7 +177,9 @@ async fn batching_task( /// Wrap a future inside a match statement to handle errors and send the response to the Batcher async fn wrap_future( - future: impl Future, Option, Vec), ClientError>>, + future: impl Future< + Output = Result<(Vec, Option, Vec), ClientError>, + >, request_ids: Vec, db: &Db, ) -> Option { @@ -217,13 +212,15 @@ fn send_generated(finished: Vec, intermediates: Vec intermediates.into_iter().for_each(|intermediate| { // We can `expect` here as the request id should always be in the DB let guard = db.get_mutex_guard(); - let entry = guard.entries.get(&intermediate.request_id).expect("ID not found in db. This is a bug."); - + let entry = guard + .entries + .get(&intermediate.request_id) + .expect("ID not found in db. This is a bug."); if let Some(tx) = &entry.intermediate_tx { - // unwrap_or is valid here as we don't care if the receiver is gone. - tx.send(Ok(Some(intermediate))).unwrap_or(()); - } + // unwrap_or is valid here as we don't care if the receiver is gone. + tx.send(Ok(Some(intermediate))).unwrap_or(()); + } }); finished.into_iter().for_each(|output| { @@ -231,7 +228,7 @@ fn send_generated(finished: Vec, intermediates: Vec let entry = db .remove(&output.request.unwrap().id) .expect("ID not found in db. This is a bug."); - + if let Some(tx) = &entry.intermediate_tx { tx.send(Ok(None)).unwrap_or(()); } diff --git a/router/src/server.rs b/router/src/server.rs index 14418a28..e60d94c8 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -8,17 +8,17 @@ use axum::routing::{get, post}; use axum::{Json, Router}; use std::net::SocketAddr; use std::sync::Arc; -use text_generation_client::{ShardedClient, IntermediateEvent}; +use text_generation_client::{IntermediateEvent, ShardedClient}; use tokenizers::Tokenizer; use tokio::signal; use tokio::sync::Semaphore; +use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; use tracing::instrument; -use tokio::sync::{oneshot, mpsc}; use axum::response::sse::{Event, KeepAlive, Sse}; -use std::convert::Infallible; use futures::stream::Stream; +use std::convert::Infallible; // Server shared state #[derive(Clone)] @@ -81,62 +81,71 @@ async fn generate_stream( let (intermediate_tx, mut intermediate_rx) = mpsc::unbounded_channel(); let (response_tx, response_rx) = oneshot::channel(); - let (input_length, validated_request) = - state.validation.validate(req.0).await.map_err(|err| { + let (input_length, validated_request) = state + .validation + .validate(req.0) + .await + .map_err(|err| { tracing::error!("{}", err.to_string()); err - }).unwrap(); + }) + .unwrap(); // Inference - state.batcher.infer_stream(input_length, validated_request, intermediate_tx, response_tx); - - let stream = async_stream::stream! { - while let Some(item) = intermediate_rx.recv().await { - match item { - Ok(item) => { - match item { - Some(item) => { - let event_data = IntermediateEvent { - token: item.token, - token_id: item.token_id, - logprob: item.logprob, - }; - let stream_event = StreamEvent { - is_end: false, - event: Some(event_data), - generated_text: None, - }; - yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap())); - } - None => { - break - } + state.batcher.infer_stream( + input_length, + validated_request, + intermediate_tx, + response_tx, + ); + + let stream = async_stream::stream! { + while let Some(item) = intermediate_rx.recv().await { + match item { + Ok(item) => { + match item { + Some(item) => { + let event_data = IntermediateEvent { + token: item.token, + token_id: item.token_id, + logprob: item.logprob, + }; + let stream_event = StreamEvent { + is_end: false, + event: Some(event_data), + generated_text: None, + }; + yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap())); + } + None => { + break } } - Err(err) => { - yield Ok(Event::default().data(err.to_string())); - } - } - } - let response = response_rx.await.unwrap(); - match response { - Ok(response) => { - let response = GeneratedText { - generated_text: response.output_text, - details: None, - }; - let stream_event = StreamEvent { - is_end: true, - event: None, - generated_text: Some(response), - }; - yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap())); } Err(err) => { yield Ok(Event::default().data(err.to_string())); } } - }; + } + let response = response_rx.await.unwrap(); + match response { + Ok(response) => { + let response = GeneratedText { + generated_text: response.output_text, + details: None, + }; + let stream_event = StreamEvent { + is_end: true, + event: None, + generated_text: Some(response), + }; + yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap())); + } + Err(err) => { + yield Ok(Event::default().data(err.to_string())); + } + } + }; Sse::new(stream).keep_alive(KeepAlive::default()) }