diff --git a/router/src/batcher.rs b/router/src/batcher.rs
index ee83d899..29d7704b 100644
--- a/router/src/batcher.rs
+++ b/router/src/batcher.rs
@@ -1,3 +1,4 @@
+use std::collections::HashMap;
 /// Batching and inference logic
 use crate::{Db, Entry};
 use crate::{ErrorResponse, GenerateRequest};
@@ -8,6 +9,7 @@ use std::sync::Arc;
 use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use thiserror::Error;
 use tokio::sync::{oneshot, Notify};
+use tokio::sync::{Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tracing::instrument;

@@ -24,6 +26,8 @@ pub struct Batcher {
 struct Shared {
     /// Batching background Tokio task notifier
     batching_task: Notify,
+    /// Inference request limit
+    limit_concurrent_requests: Semaphore,
 }

 impl Batcher {
@@ -31,11 +35,13 @@
         client: ShardedClient,
         max_batch_size: usize,
         max_waiting_tokens: usize,
+        max_concurrent_requests: usize,
     ) -> Self {
         // Batcher shared state
         let db = Db::new();
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
+            limit_concurrent_requests: Semaphore::new(max_concurrent_requests),
         });

         // Spawn batching background task that contains all the inference logic
@@ -56,6 +62,9 @@
         input_length: usize,
         request: GenerateRequest,
     ) -> Result<InferResponse, InferError> {
+        // Limit concurrent requests by acquiring a permit from the semaphore
+        let _permit = self.shared.limit_concurrent_requests.try_acquire()?;
+
         // One shot channel to communicate with the background batching task
         let (response_tx, response_rx) = oneshot::channel();

@@ -104,8 +113,8 @@
         // Get the next batch from the DB
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the DB
-        while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
-            let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
+        while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) {
+            let mut cached_batch = wrap_future(client.generate(batch), &mut entries).await;
             let mut waiting_tokens = 1;

             // We loop until we do not receive any cached batch from the inference server (== until
@@ -113,7 +122,6 @@
             while let Some(batch) = cached_batch {
                 // Get current batch info
                 let batch_size = batch.size;
-                let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
                 let mut batches = vec![batch];

                 // If the current batch is too small, we try to add more requests to it
@@ -127,24 +135,24 @@
                     };

                     // Try to get a new batch
-                    if let Some((new_request_ids, new_batch)) =
+                    if let Some((mut new_entries, new_batch)) =
                         db.next_batch(min_size, max_batch_size - batch_size as usize)
                     {
                         // Generate one token for this new batch to have the attention past in cache
                         let new_cached_batch =
-                            wrap_future(client.generate(new_batch), new_request_ids, &db).await;
+                            wrap_future(client.generate(new_batch), &mut new_entries).await;
                         // Reset waiting counter
                         waiting_tokens = 1;
                         // Extend current batch with the new batch
                         if let Some(new_cached_batch) = new_cached_batch {
-                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
+                            entries.extend(new_entries);
                             batches.push(new_cached_batch);
                         }
                     }
                 }
                 cached_batch =
-                    wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
+                    wrap_future(client.generate_with_cache(batches), &mut entries).await;
                 waiting_tokens += 1;
             }
         }
     }
@@ -154,39 +162,36 @@
 /// Wrap a future inside a match statement to handle errors and send the response to the Batcher
 async fn wrap_future(
     future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
-    request_ids: Vec<u64>,
-    db: &Db,
+    entries: &mut HashMap<u64, Entry>,
 ) -> Option<Batch> {
     match future.await {
         Ok((generated_texts, next_batch)) => {
-            send_generated(generated_texts, db);
+            send_generated(generated_texts, entries);
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
-            send_error(err, request_ids, db);
+            send_error(err, entries);
             None
         }
     }
 }

-/// Send errors to the Batcher for all `request_ids`
-fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
-    request_ids.into_iter().for_each(|id| {
-        // We can `expect` here as the request id should always be in the DB
-        let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
+/// Send errors to the Batcher for all `entries`
+fn send_error(error: ClientError, entries: &mut HashMap<u64, Entry>) {
+    entries.drain().for_each(|(_, entry)| {
         // unwrap_or is valid here as we don't care if the receiver is gone.
         entry.response_tx.send(Err(error.clone())).unwrap_or(());
     });
 }

 /// Send `generated_text` to the Batcher for all `finished`
-fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
+fn send_generated(finished: Vec<GeneratedText>, entries: &mut HashMap<u64, Entry>) {
     finished.into_iter().for_each(|output| {
-        // We can `expect` here as the request id should always be in the DB
-        let entry = db
+        // We can `expect` here as the request id should always be in the entries
+        let entry = entries
             .remove(&output.request.unwrap().id)
-            .expect("ID not found in db. This is a bug.");
+            .expect("ID not found in entries. This is a bug.");

         let response = InferResponse {
             output_text: output.output_text,
@@ -221,18 +226,30 @@ pub(crate) struct InferResponse {
 pub enum InferError {
     #[error("Request failed during generation: {0}")]
     GenerationError(String),
+    #[error("Model is overloaded")]
+    Overloaded,
+}
+
+/// Convert semaphore error
+impl From<TryAcquireError> for InferError {
+    fn from(_: TryAcquireError) -> Self {
+        InferError::Overloaded
+    }
 }

 /// Convert to Axum supported format
 impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
     fn from(err: InferError) -> Self {
-        match err {
-            InferError::GenerationError(_) => (
-                StatusCode::FAILED_DEPENDENCY,
-                Json(ErrorResponse {
-                    error: err.to_string(),
-                }),
-            ),
-        }
+        let status_code = match err {
+            InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
+            InferError::Overloaded => StatusCode::TOO_MANY_REQUESTS,
+        };
+
+        (
+            status_code,
+            Json(ErrorResponse {
+                error: err.to_string(),
+            }),
+        )
     }
 }
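Note: admission control now lives in `Batcher::infer` itself rather than in the HTTP handlers. The sketch below is a minimal, self-contained illustration of that pattern with simplified, hypothetical types (`Batcher`, `InferError` and `infer` here are stand-ins, not the crate's real signatures); it assumes a `tokio` dependency with the `sync`, `rt` and `macros` features.

    use tokio::sync::{Semaphore, TryAcquireError};

    #[derive(Debug)]
    enum InferError {
        Overloaded,
    }

    // Mirrors the new `From<TryAcquireError> for InferError` conversion in batcher.rs.
    impl From<TryAcquireError> for InferError {
        fn from(_: TryAcquireError) -> Self {
            InferError::Overloaded
        }
    }

    struct Batcher {
        // Bounds how many requests may be inside `infer` at any one time.
        limit_concurrent_requests: Semaphore,
    }

    impl Batcher {
        async fn infer(&self) -> Result<String, InferError> {
            // `try_acquire` fails immediately instead of queueing, so excess requests
            // are rejected as `Overloaded` rather than piling up behind the batcher.
            let _permit = self.limit_concurrent_requests.try_acquire()?;
            Ok("generated text".to_string())
        }
    }

    #[tokio::main]
    async fn main() {
        let batcher = Batcher {
            limit_concurrent_requests: Semaphore::new(1),
        };
        // Hold the only permit to simulate a saturated batcher...
        let _held = batcher.limit_concurrent_requests.try_acquire().unwrap();
        // ...so the next call is rejected instead of waiting.
        assert!(matches!(batcher.infer().await, Err(InferError::Overloaded)));
    }

Holding the permit for the whole `infer` call is what bounds the number of in-flight requests; dropping it at the end of the call releases the slot for the next caller.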
diff --git a/router/src/db.rs b/router/src/db.rs
index 1d7df627..f4302f80 100644
--- a/router/src/db.rs
+++ b/router/src/db.rs
@@ -2,7 +2,7 @@ use crate::InferResponse;
 /// This code is massively inspired by Tokio mini-redis
 use crate::{GenerateParameters, GenerateRequest};
 use parking_lot::Mutex;
-use std::collections::BTreeMap;
+use std::collections::{BTreeMap, HashMap};
 use std::sync::Arc;
 use text_generation_client::{
     Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
@@ -112,18 +112,12 @@
         state.entries.insert(id, entry);
     }

-    /// Remove an entry from the database if it exists
-    pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
-        let mut state = self.shared.state.lock();
-        state.entries.remove(id)
-    }
-
     // Get the next batch
     pub(crate) fn next_batch(
         &self,
         min_size: Option<usize>,
         max_size: usize,
-    ) -> Option<(Vec<u64>, Batch)> {
+    ) -> Option<(HashMap<u64, Entry>, Batch)> {
         // Acquire lock
         let mut state = self.shared.state.lock();

@@ -135,13 +129,19 @@
                     return None;
                 }
             }

-            ids.iter().for_each(|id| {
-                // Set batch_time for each request
-                state.entries.get_mut(id).unwrap().batch_time = Some(Instant::now());
-            });
-            // Batch size
             let size = requests.len();
+
+            let mut entries = HashMap::with_capacity(size);
+            ids.iter().for_each(|id| {
+                // Remove entry from db
+                let mut entry = state.entries.remove(id).unwrap();
+                // Set batch_time
+                entry.batch_time = Some(Instant::now());
+                // Insert in entries hashmap
+                entries.insert(*id, entry);
+            });
+
             let batch = Batch {
                 id: state.next_batch_id,
                 requests,
@@ -152,7 +152,7 @@
             // Increment batch id
             state.next_batch_id += 1;
-            return Some((ids, batch));
+            return Some((entries, batch));
         }
         None
     }
 }
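Note: `Db::next_batch` now hands ownership of the selected entries to the batching task instead of returning bare ids, which is why `Db::remove` (and its extra lock acquisition per finished request) can be deleted. A rough sketch of that handoff, using simplified, hypothetical stand-ins rather than the router's real `Entry` and state definitions:

    use std::collections::{BTreeMap, HashMap};
    use std::time::Instant;

    // Hypothetical, trimmed-down stand-in for the router's `Entry`.
    struct Entry {
        input: String,
        batch_time: Option<Instant>,
    }

    struct State {
        entries: BTreeMap<u64, Entry>,
    }

    impl State {
        // Same flow as the new `next_batch`: selected entries are removed from the shared
        // map and returned to the caller, so the batching task owns them and never has to
        // lock the Db again just to look a request id back up.
        fn take_batch(&mut self, ids: &[u64]) -> HashMap<u64, Entry> {
            let mut batch_entries = HashMap::with_capacity(ids.len());
            for id in ids {
                if let Some(mut entry) = self.entries.remove(id) {
                    entry.batch_time = Some(Instant::now());
                    batch_entries.insert(*id, entry);
                }
            }
            batch_entries
        }
    }

    fn main() {
        let mut state = State {
            entries: BTreeMap::from([(
                0,
                Entry {
                    input: "hello".to_string(),
                    batch_time: None,
                },
            )]),
        };
        let taken = state.take_batch(&[0]);
        // The entry left the shared map and now lives with the batch.
        assert!(state.entries.is_empty());
        assert_eq!(taken[&0].input, "hello");
        assert!(taken[&0].batch_time.is_some());
    }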
diff --git a/router/src/server.rs b/router/src/server.rs
index 623dd07c..fc8f7848 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -7,11 +7,9 @@ use axum::response::IntoResponse;
 use axum::routing::{get, post};
 use axum::{Json, Router};
 use std::net::SocketAddr;
-use std::sync::Arc;
 use text_generation_client::ShardedClient;
 use tokenizers::Tokenizer;
 use tokio::signal;
-use tokio::sync::Semaphore;
 use tokio::time::Instant;
 use tracing::instrument;

@@ -20,7 +18,6 @@ use tracing::instrument;
 struct ServerState {
     validation: Validation,
     batcher: Batcher,
-    limit_concurrent_requests: Arc<Semaphore>,
 }

 /// Health check method
@@ -30,16 +27,6 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
     //       be a bit too slow for a health check.
     //       What we should do instead is check if the gRPC channels are still healthy.

-    // Limit concurrent requests by acquiring a permit from the semaphore
-    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
-        (
-            StatusCode::TOO_MANY_REQUESTS,
-            Json(ErrorResponse {
-                error: "Model is overloaded".to_string(),
-            }),
-        )
-    })?;
-
     // Send a small inference request
     state
         .batcher
@@ … @@ async fn generate(
     req: Json<GenerateRequest>,
 ) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
     let start_time = Instant::now();
-    // Limit concurrent requests by acquiring a permit from the semaphore
-    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
-        tracing::error!("Model is overloaded");
-        (
-            StatusCode::TOO_MANY_REQUESTS,
-            Json(ErrorResponse {
-                error: "Model is overloaded".to_string(),
-            }),
-        )
-    })?;

     // Validate request
     let details = req.0.parameters.details;
@@ -185,12 +162,16 @@ pub async fn run(
     addr: SocketAddr,
 ) {
     // Create state
-    let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
+    let batcher = Batcher::new(
+        client,
+        max_batch_size,
+        max_waiting_tokens,
+        max_concurrent_requests,
+    );
     let validation = Validation::new(validation_workers, tokenizer, max_input_length);
     let shared_state = ServerState {
         validation,
         batcher,
-        limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
     };

     // Create router
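Note: with the semaphore gone from the Axum layer, overload reporting flows through the `From<InferError>` conversion added in batcher.rs, so a handler only needs `?` on `infer` to turn an overloaded batcher into an HTTP 429. A small, self-contained sketch of that conversion with a trimmed-down, hypothetical error enum (assumes `axum` and `serde` with the `derive` feature; the real enum derives its Display messages via `thiserror`):

    use axum::http::StatusCode;
    use axum::Json;
    use serde::Serialize;

    #[derive(Serialize)]
    struct ErrorResponse {
        error: String,
    }

    // Hypothetical simplified error enum standing in for the router's `InferError`.
    enum InferError {
        GenerationError(String),
        Overloaded,
    }

    // Same shape as the new conversion in batcher.rs: pick the status per variant,
    // then build a single JSON error body.
    impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
        fn from(err: InferError) -> Self {
            let (status_code, error) = match err {
                InferError::GenerationError(msg) => (
                    StatusCode::FAILED_DEPENDENCY,
                    format!("Request failed during generation: {msg}"),
                ),
                InferError::Overloaded => {
                    (StatusCode::TOO_MANY_REQUESTS, "Model is overloaded".to_string())
                }
            };
            (status_code, Json(ErrorResponse { error }))
        }
    }

    fn main() {
        // In a handler, `state.batcher.infer(...).await?` performs this conversion
        // implicitly, so an overloaded batcher surfaces as HTTP 429.
        let (status, body): (StatusCode, Json<ErrorResponse>) = InferError::Overloaded.into();
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(body.0.error, "Model is overloaded");
    }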