From 122c137b56381f55b08f2a5b3225f272e61276af Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Sat, 28 Jan 2023 09:31:37 +0100 Subject: [PATCH] rust code cleanup --- router/client/src/client.rs | 6 +- router/client/src/sharded_client.rs | 26 ++-- router/src/db.rs | 6 +- router/src/{batcher.rs => infer.rs} | 108 +++++++++------ router/src/lib.rs | 4 +- router/src/server.rs | 196 ++++++++++++++-------------- router/src/validation.rs | 15 +-- 7 files changed, 191 insertions(+), 170 deletions(-) rename router/src/{batcher.rs => infer.rs} (75%) diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 01bf1f17..77a43110 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -70,7 +70,7 @@ impl Client { /// Generate one token for each request in the given batch /// - /// Returns a list of generated texts of request that met their stopping criteria + /// Returns Generation for each request in batch /// and the next cached batch #[instrument(skip(self))] pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec, Option)> { @@ -84,9 +84,9 @@ impl Client { Ok((response.generations, response.batch)) } - /// Generate one token for each request in the given cached batch + /// Generate one token for each request in the given cached batches /// - /// Returns a list of generated texts of request that met their stopping criteria + /// Returns Generation for each request in batches /// and the next cached batch #[instrument(skip(self))] pub async fn decode( diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 015d431d..56335f92 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -37,9 +37,19 @@ impl ShardedClient { Self::from_master_client(master_client).await } + /// Clear the past generations cache + pub async fn clear_cache(&mut self) -> Result<()> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.clear_cache()) + .collect(); + join_all(futures).await.into_iter().collect() + } + /// Generate one token for each request in the given batch /// - /// Returns a list of generated texts of request that met their stopping criteria + /// Returns Generation for each request in batch /// and the next cached batch pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec, Option)> { let futures: Vec<_> = self @@ -52,9 +62,9 @@ impl ShardedClient { result } - /// Generate one token for each request in the given cached batch + /// Generate one token for each request in the given cached batches /// - /// Returns a list of generated texts of request that met their stopping criteria + /// Returns Generation for each request in batches /// and the next cached batch pub async fn decode( &mut self, @@ -69,14 +79,4 @@ impl ShardedClient { let (result, _, _) = select_all(futures).await; result } - - /// Clear the past generations cache - pub async fn clear_cache(&mut self) -> Result<()> { - let futures: Vec<_> = self - .clients - .iter_mut() - .map(|client| client.clear_cache()) - .collect(); - join_all(futures).await.into_iter().collect() - } } diff --git a/router/src/db.rs b/router/src/db.rs index 34b26599..8997094c 100644 --- a/router/src/db.rs +++ b/router/src/db.rs @@ -1,6 +1,6 @@ -use crate::batcher::InferError; /// This code is massively inspired by Tokio mini-redis -use crate::batcher::InferStreamResponse; +use crate::infer::InferError; +use crate::infer::InferStreamResponse; use crate::{GenerateParameters, 
GenerateRequest}; use nohash_hasher::{BuildNoHashHasher, IntMap}; use parking_lot::Mutex; @@ -17,7 +17,7 @@ use tokio::time::Instant; pub(crate) struct Entry { /// Request pub request: GenerateRequest, - /// Response sender to communicate between the Batcher and the batching_task + /// Response sender to communicate between the Infer struct and the batching_task pub response_tx: UnboundedSender>, /// Number of tokens in the input pub input_length: usize, diff --git a/router/src/batcher.rs b/router/src/infer.rs similarity index 75% rename from router/src/batcher.rs rename to router/src/infer.rs index d85efedf..62cd0248 100644 --- a/router/src/batcher.rs +++ b/router/src/infer.rs @@ -1,43 +1,49 @@ /// Batching and inference logic +use crate::validation::{Validation, ValidationError}; +use crate::GenerateRequest; use crate::{Db, Entry, Token}; -use crate::{ErrorResponse, GenerateRequest}; -use axum::http::StatusCode; -use axum::Json; use nohash_hasher::IntMap; use std::future::Future; use std::sync::Arc; use text_generation_client::{Batch, ClientError, GeneratedText, Generation, ShardedClient}; use thiserror::Error; -use tokio::sync::{mpsc, Notify}; +use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError}; use tokio::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::StreamExt; use tracing::instrument; -/// Batcher +/// Inference struct #[derive(Clone)] -pub struct Batcher { +pub struct Infer { + /// Validation + validation: Validation, /// Request database db: Db, /// Shared state shared: Arc, } -/// Batcher shared state +/// Infer shared state struct Shared { + /// Inference limit + limit_concurrent_requests: Semaphore, /// Batching background Tokio task notifier batching_task: Notify, } -impl Batcher { +impl Infer { pub(crate) fn new( client: ShardedClient, + validation: Validation, max_batch_size: usize, max_waiting_tokens: usize, + max_concurrent_requests: usize, ) -> Self { - // Batcher shared state + // Infer shared state let db = Db::new(); let shared = Arc::new(Shared { + limit_concurrent_requests: Semaphore::new(max_concurrent_requests), batching_task: Notify::new(), }); @@ -50,21 +56,30 @@ impl Batcher { shared.clone(), )); - Self { db, shared } + Self { + validation, + db, + shared, + } } - /// Add a new request to the database and return a stream of tokens - pub(crate) fn infer_stream( + /// Add a new request to the database and return a stream of InferStreamResponse + pub(crate) async fn generate_stream( &self, - input_length: usize, request: GenerateRequest, - ) -> UnboundedReceiverStream> { + ) -> Result>, InferError> { + // Limit concurrent requests by acquiring a permit from the semaphore + let _permit = self.shared.limit_concurrent_requests.try_acquire()?; + + // Validate request + let (input_length, validated_request) = self.validation.validate(request).await?; + // MPSC channel to communicate with the background batching task let (response_tx, response_rx) = mpsc::unbounded_channel(); // Try to append the request to the database self.db.append(Entry { - request, + request: validated_request, response_tx, input_length, time: Instant::now(), @@ -76,27 +91,34 @@ impl Batcher { self.shared.batching_task.notify_one(); // Return stream - UnboundedReceiverStream::new(response_rx) + Ok(UnboundedReceiverStream::new(response_rx)) } - pub(crate) async fn infer( + /// Add a new request to the database and return a InferResponse + pub(crate) async fn generate( &self, - input_length: usize, request: GenerateRequest, ) -> Result { - let mut stream 
= self.infer_stream(input_length, request); + // Create stream + let mut stream = self.generate_stream(request).await?; + // Return values let mut result_tokens = Vec::new(); let mut result_generated_text = None; let mut result_start = None; let mut result_queued = None; + // Iterate on stream while let Some(response) = stream.next().await { match response? { + // Add prefill tokens InferStreamResponse::Prefill(prefill_tokens) => { result_tokens.extend(prefill_tokens) } + // Push last token InferStreamResponse::Token(token) => result_tokens.push(token), + // Final message + // Set return values InferStreamResponse::End { generated_text, start, @@ -108,6 +130,7 @@ impl Batcher { } } } + // Unwrap is safe here Ok(InferResponse { tokens: result_tokens, generated_text: result_generated_text.unwrap(), @@ -134,7 +157,7 @@ async fn batching_task( // Infinite loop loop { - // Wait for a notification from the Batcher struct + // Wait for a notification from the Infer struct shared.batching_task.notified().await; // Get the next batch from the DB @@ -185,14 +208,14 @@ async fn batching_task( } } -/// Wrap a future inside a match statement to handle errors and send the response to the Batcher +/// Wrap a future inside a match statement to handle errors and send the responses to Infer async fn wrap_future( future: impl Future, Option), ClientError>>, entries: &mut IntMap, ) -> Option { match future.await { Ok((generations, next_batch)) => { - send_generated(generations, entries); + send_generations(generations, entries); next_batch } // If we have an error, we discard the whole batch @@ -203,7 +226,7 @@ async fn wrap_future( } } -/// Send errors to the Batcher for all `entries` +/// Send errors to Infer for all `entries` fn send_error(error: ClientError, entries: &mut IntMap) { entries.drain().for_each(|(_, entry)| { // unwrap_or is valid here as we don't care if the receiver is gone. @@ -214,14 +237,18 @@ fn send_error(error: ClientError, entries: &mut IntMap) { }); } -/// Send `generated_text` to the Batcher for all `finished` -fn send_generated(generations: Vec, entries: &mut IntMap) { +/// Send one or multiple `InferStreamResponse` to Infer for all `entries` +fn send_generations(generations: Vec, entries: &mut IntMap) { generations.into_iter().for_each(|generation| { + // Get entry + // We can `expect` here as the request id should always be in the entries let entry = entries .get(&generation.request_id) .expect("ID not found in entries. This is a bug."); if let Some(prefill_tokens) = generation.prefill_tokens { + // Create Token objects + // We do that here instead of in the Python code as Rust for loops are faster let tokens = prefill_tokens .ids .into_iter() @@ -229,27 +256,37 @@ fn send_generated(generations: Vec, entries: &mut IntMap .zip(prefill_tokens.texts.into_iter()) .map(|((id, logprob), text)| Token(id, text, logprob)) .collect(); + // Send message + // unwrap_or is valid here as we don't care if the receiver is gone. entry .response_tx .send(Ok(InferStreamResponse::Prefill(tokens))) .unwrap_or(()); } + // Create last Token let token = Token( generation.token_id, generation.token_text, generation.token_logprob, ); + + // Send message + // unwrap_or is valid here as we don't care if the receiver is gone. 
entry .response_tx .send(Ok(InferStreamResponse::Token(token))) .unwrap_or(()); if let Some(generated_text) = generation.generated_text { + // Remove entry as this is the last message + // We can `expect` here as the request id should always be in the entries let entry = entries .remove(&generation.request_id) .expect("ID not found in entries. This is a bug."); + // Send message + // unwrap_or is valid here as we don't care if the receiver is gone. entry .response_tx .send(Ok(InferStreamResponse::End { @@ -264,8 +301,11 @@ fn send_generated(generations: Vec, entries: &mut IntMap #[derive(Debug)] pub(crate) enum InferStreamResponse { + // Optional first message Prefill(Vec), + // Intermediate messages Token(Token), + // Last message End { generated_text: GeneratedText, start: Instant, @@ -286,18 +326,8 @@ pub(crate) struct InferResponse { pub enum InferError { #[error("Request failed during generation: {0}")] GenerationError(String), -} - -/// Convert to Axum supported format -impl From for (StatusCode, Json) { - fn from(err: InferError) -> Self { - match err { - InferError::GenerationError(_) => ( - StatusCode::FAILED_DEPENDENCY, - Json(ErrorResponse { - error: err.to_string(), - }), - ), - } - } + #[error("Model is overloaded")] + Overloaded(#[from] TryAcquireError), + #[error("Input validation error: {0}")] + ValidationError(#[from] ValidationError), } diff --git a/router/src/lib.rs b/router/src/lib.rs index aab253b7..de3b7d78 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1,11 +1,11 @@ /// Text Generation Inference Webserver -mod batcher; mod db; +mod infer; pub mod server; mod validation; -use batcher::Batcher; use db::{Db, Entry}; +use infer::Infer; use serde::{Deserialize, Serialize}; use validation::Validation; diff --git a/router/src/server.rs b/router/src/server.rs index cda38fbe..56161597 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,75 +1,52 @@ -use crate::batcher::InferStreamResponse; +/// HTTP Server logic +use crate::infer::{InferError, InferStreamResponse}; use crate::{ - Batcher, Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation, + Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Infer, Validation, }; use axum::extract::Extension; use axum::http::{HeaderMap, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::IntoResponse; use axum::routing::{get, post}; -use axum::{BoxError, Json, Router}; +use axum::{Json, Router}; use futures::Stream; use std::net::SocketAddr; -use std::sync::Arc; use text_generation_client::ShardedClient; use tokenizers::Tokenizer; use tokio::signal; -use tokio::sync::Semaphore; use tokio::time::Instant; use tokio_stream::StreamExt; use tracing::instrument; -// Server shared state -#[derive(Clone)] -struct ServerState { - validation: Validation, - batcher: Batcher, - limit_concurrent_requests: Arc, -} - /// Health check method -#[instrument(skip(state), fields(time, time_per_token))] -async fn health(state: Extension) -> Result<(), (StatusCode, Json)> { +#[instrument(skip(infer))] +async fn health(infer: Extension) -> Result<(), (StatusCode, Json)> { // TODO: while this is the best health check we can do, it is a bit on the heavy side and might // be a bit too slow for a health check. // What we should do instead if check if the gRPC channels are still healthy. 
- // Limit concurrent requests by acquiring a permit from the semaphore - let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| { - ( - StatusCode::TOO_MANY_REQUESTS, - Json(ErrorResponse { - error: "Model is overloaded".to_string(), - }), - ) - })?; - // Send a small inference request - state - .batcher - .infer( - 1, - GenerateRequest { - inputs: "liveness".to_string(), - parameters: GenerateParameters { - temperature: 1.0, - top_k: 0, - top_p: 1.0, - do_sample: false, - max_new_tokens: 1, - stop: vec![], - details: false, - seed: None, - }, + infer + .generate(GenerateRequest { + inputs: "liveness".to_string(), + parameters: GenerateParameters { + temperature: 1.0, + top_k: 0, + top_p: 1.0, + do_sample: false, + max_new_tokens: 1, + stop: vec![], + details: false, + seed: None, }, - ) + }) .await?; Ok(()) } /// Generate method #[instrument( - skip(state), + skip(infer), fields( total_time, validation_time, @@ -80,38 +57,17 @@ async fn health(state: Extension) -> Result<(), (StatusCode, Json, + infer: Extension, req: Json, ) -> Result)> { let start_time = Instant::now(); - // Limit concurrent requests by acquiring a permit from the semaphore - let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| { - tracing::error!("Model is overloaded"); - ( - StatusCode::TOO_MANY_REQUESTS, - Json(ErrorResponse { - error: "Model is overloaded".to_string(), - }), - ) - })?; - - // Validate request - let details = req.0.parameters.details; - let (input_length, validated_request) = - state.validation.validate(req.0).await.map_err(|err| { - tracing::error!("{}", err.to_string()); - err - })?; // Inference - let response = state - .batcher - .infer(input_length, validated_request) - .await - .map_err(|err| { - tracing::error!("{}", err.to_string()); - err - })?; + let details = req.0.parameters.details; + let response = infer.generate(req.0).await.map_err(|err| { + tracing::error!("{}", err.to_string()); + err + })?; // Token details let details = match details { @@ -171,39 +127,68 @@ async fn generate( Ok((headers, Json(response))) } +/// Generate stream method +#[instrument( + skip(infer), + fields( + total_time, + validation_time, + queue_time, + inference_time, + time_per_token + ) +)] async fn generate_stream( - state: Extension, + infer: Extension, req: Json, -) -> Sse>> { +) -> Sse>> { let stream = async_stream::stream! 
{ - // Limit concurrent requests by acquiring a permit from the semaphore - let _permit = state.limit_concurrent_requests.try_acquire().map_err(| err | { - tracing::error!("Model is overloaded"); - err - })?; - - // Validate request - let (input_length, validated_request) = - state.validation.validate(req.0).await.map_err(|err| { - tracing::error!("{}", err); - err - })?; + let start_time = Instant::now(); // Inference - let mut response_stream = state - .batcher - .infer_stream(input_length, validated_request); + let mut response_stream = infer.generate_stream(req.0).await?; + // Server Side Event stream while let Some(response) = response_stream.next().await { match response { Ok(response) => { - if let InferStreamResponse::Token(token) = response { - yield Ok(Event::default().json_data(token).unwrap()); + match response { + // Prefill is ignored + InferStreamResponse::Prefill(_) => {} + // Yield event for every new token + InferStreamResponse::Token(token) => { + yield Ok(Event::default().json_data(token).unwrap()) + } + // End is used for timings metadata and logging + InferStreamResponse::End { + generated_text, + start, + queued, + } => { + // Timings + let total_time = start_time.elapsed(); + let validation_time = queued - start_time; + let queue_time = start - queued; + let inference_time = Instant::now() - start; + let time_per_token = inference_time / generated_text.generated_tokens; + + // Tracing metadata + tracing::Span::current().record("total_time", format!("{:?}", total_time)); + tracing::Span::current() + .record("validation_time", format!("{:?}", validation_time)); + tracing::Span::current().record("queue_time", format!("{:?}", queue_time)); + tracing::Span::current() + .record("inference_time", format!("{:?}", inference_time)); + tracing::Span::current() + .record("time_per_token", format!("{:?}", time_per_token)); + tracing::info!("Output: {}", generated_text.text); + } } } + // Trace and yield error Err(err) => { tracing::error!("{}", err.to_string()); - yield Ok(Event::default().data(err.to_string())); + yield Err(err); } } } @@ -225,13 +210,14 @@ pub async fn run( addr: SocketAddr, ) { // Create state - let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens); let validation = Validation::new(validation_workers, tokenizer, max_input_length); - let shared_state = ServerState { + let infer = Infer::new( + client, validation, - batcher, - limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)), - }; + max_batch_size, + max_waiting_tokens, + max_concurrent_requests, + ); // Create router let app = Router::new() @@ -240,7 +226,7 @@ pub async fn run( .route("/generate_stream", post(generate_stream)) .route("/", get(health)) .route("/health", get(health)) - .layer(Extension(shared_state.clone())); + .layer(Extension(infer)); // Run server axum::Server::bind(&addr) @@ -277,3 +263,21 @@ async fn shutdown_signal() { tracing::info!("signal received, starting graceful shutdown"); } + +/// Convert to Axum supported format +impl From for (StatusCode, Json) { + fn from(err: InferError) -> Self { + let status_code = match err { + InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY, + InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS, + InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY, + }; + + ( + status_code, + Json(ErrorResponse { + error: err.to_string(), + }), + ) + } +} diff --git a/router/src/validation.rs b/router/src/validation.rs index 674635d3..d9579774 100644 --- a/router/src/validation.rs +++ 
b/router/src/validation.rs @@ -1,7 +1,5 @@ /// Payload validation logic -use crate::{ErrorResponse, GenerateRequest}; -use axum::http::StatusCode; -use axum::Json; +use crate::GenerateRequest; use rand::rngs::ThreadRng; use rand::Rng; use thiserror::Error; @@ -172,14 +170,3 @@ pub enum ValidationError { #[error("tokenizer error {0}")] Tokenizer(String), } - -impl From for (StatusCode, Json) { - fn from(err: ValidationError) -> Self { - ( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: err.to_string(), - }), - ) - } -}
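
Note (illustration, not part of the patch): beyond the batcher.rs -> infer.rs rename, the core of this change is that `Infer` now owns the concurrency limit (a `tokio::sync::Semaphore`) together with the `Validation` handle, and surfaces failures as typed `InferError` variants that the new `From<InferError>` impl in server.rs maps to HTTP status codes. A minimal, self-contained sketch of that pattern is below; `FakeInfer`, `status_code` and `main` are hypothetical stand-ins for illustration only, not the router's real `Infer`, `GenerateRequest` or axum response types.

use std::sync::Arc;

use thiserror::Error;
use tokio::sync::{Semaphore, TryAcquireError};

#[derive(Debug, Error)]
enum InferError {
    #[error("Request failed during generation: {0}")]
    GenerationError(String),
    #[error("Model is overloaded")]
    Overloaded(#[from] TryAcquireError),
}

/// Rough analogue of the `From<InferError>` impl added to server.rs,
/// returning a bare status code instead of an axum response.
fn status_code(err: &InferError) -> u16 {
    match err {
        InferError::GenerationError(_) => 424, // FAILED_DEPENDENCY
        InferError::Overloaded(_) => 429,      // TOO_MANY_REQUESTS
    }
}

/// Illustrative stand-in for the router's `Infer` struct.
#[derive(Clone)]
struct FakeInfer {
    limit_concurrent_requests: Arc<Semaphore>,
}

impl FakeInfer {
    fn new(max_concurrent_requests: usize) -> Self {
        Self {
            limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
        }
    }

    async fn generate(&self, input: &str) -> Result<String, InferError> {
        // `try_acquire` fails immediately when all permits are in use;
        // `?` converts the `TryAcquireError` into `InferError::Overloaded`.
        let _permit = self.limit_concurrent_requests.try_acquire()?;
        // ... validation and batched inference would run here ...
        Ok(format!("generated for: {input}"))
    }
}

#[tokio::main]
async fn main() {
    let infer = FakeInfer::new(1);
    match infer.generate("liveness").await {
        Ok(text) => println!("{text}"),
        Err(err) => eprintln!("{} -> HTTP {}", err, status_code(&err)),
    }
}

The upshot for the handlers is visible in the diff above: `generate` and `generate_stream` no longer acquire permits or call validation themselves; they propagate `InferError` with `?`/`map_err` and let the single `From<InferError>` impl pick between 424, 429 and 422.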