diff --git a/Cargo.lock b/Cargo.lock
index 752c4886..33f5d181 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1087,6 +1087,12 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "nohash-hasher"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451"
+
 [[package]]
 name = "nom"
 version = "7.1.1"
@@ -1826,6 +1832,7 @@ dependencies = [
  "axum",
  "clap 4.0.22",
  "futures",
+ "nohash-hasher",
  "parking_lot",
  "serde",
  "serde_json",
diff --git a/router/Cargo.toml b/router/Cargo.toml
index f99069d3..546f127f 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -17,6 +17,7 @@ axum = { version = "0.5.16", features = ["json", "serde_json"] }
 text-generation-client = { path = "client" }
 clap = { version = "4.0.15", features = ["derive", "env"] }
 futures = "0.3.24"
+nohash-hasher = "0.2.0"
 parking_lot = "0.12.1"
 serde = "1.0.145"
 serde_json = "1.0.85"
diff --git a/router/src/batcher.rs b/router/src/batcher.rs
index 29d7704b..624ac82d 100644
--- a/router/src/batcher.rs
+++ b/router/src/batcher.rs
@@ -1,15 +1,14 @@
-use std::collections::HashMap;
 /// Batching and inference logic
 use crate::{Db, Entry};
 use crate::{ErrorResponse, GenerateRequest};
 use axum::http::StatusCode;
 use axum::Json;
+use nohash_hasher::IntMap;
 use std::future::Future;
 use std::sync::Arc;
 use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use thiserror::Error;
 use tokio::sync::{oneshot, Notify};
-use tokio::sync::{Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tracing::instrument;
 
@@ -26,8 +25,6 @@ pub struct Batcher {
 struct Shared {
     /// Batching background Tokio task notifier
     batching_task: Notify,
-    /// Inference request limit
-    limit_concurrent_requests: Semaphore,
 }
 
 impl Batcher {
@@ -35,13 +32,11 @@
         client: ShardedClient,
         max_batch_size: usize,
         max_waiting_tokens: usize,
-        max_concurrent_requests: usize,
     ) -> Self {
         // Batcher shared state
        let db = Db::new();
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
-            limit_concurrent_requests: Semaphore::new(max_concurrent_requests),
         });
 
         // Spawn batching background task that contains all the inference logic
@@ -62,9 +57,6 @@
         input_length: usize,
         request: GenerateRequest,
     ) -> Result<InferResponse, InferError> {
-        // Limit concurrent requests by acquiring a permit from the semaphore
-        let _permit = self.shared.limit_concurrent_requests.try_acquire()?;
-
         // One shot channel to communicate with the background batching task
         let (response_tx, response_rx) = oneshot::channel();
 
@@ -151,8 +143,7 @@
                 }
             }
 
-            cached_batch =
-                wrap_future(client.generate_with_cache(batches), &mut entries).await;
+            cached_batch = wrap_future(client.generate_with_cache(batches), &mut entries).await;
             waiting_tokens += 1;
         }
     }
@@ -162,7 +153,7 @@
 /// Wrap a future inside a match statement to handle errors and send the response to the Batcher
 async fn wrap_future(
     future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
-    entries: &mut HashMap<u64, Entry>,
+    entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
     match future.await {
         Ok((generated_texts, next_batch)) => {
@@ -178,7 +169,7 @@
 }
 
 /// Send errors to the Batcher for all `entries`
-fn send_error(error: ClientError, entries: &mut HashMap<u64, Entry>) {
+fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
     entries.drain().for_each(|(_, entry)| {
         // unwrap_or is valid here as we don't care if the receiver is gone.
         entry.response_tx.send(Err(error.clone())).unwrap_or(());
@@ -186,7 +177,7 @@
 }
 
 /// Send `generated_text` to the Batcher for all `finished`
-fn send_generated(finished: Vec<GeneratedText>, entries: &mut HashMap<u64, Entry>) {
+fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>) {
     finished.into_iter().for_each(|output| {
         // We can `expect` here as the request id should always be in the entries
         let entry = entries
@@ -226,30 +217,18 @@ pub(crate) struct InferResponse {
 pub enum InferError {
     #[error("Request failed during generation: {0}")]
     GenerationError(String),
-    #[error("Model is overloaded")]
-    Overloaded,
-}
-
-/// Convert semaphore error
-impl From<TryAcquireError> for InferError {
-    fn from(_: TryAcquireError) -> Self {
-        InferError::Overloaded
-    }
 }
 
 /// Convert to Axum supported format
 impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
     fn from(err: InferError) -> Self {
-        let status_code = match err {
-            InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
-            InferError::Overloaded => StatusCode::TOO_MANY_REQUESTS,
-        };
-
-        (
-            status_code,
-            Json(ErrorResponse {
-                error: err.to_string(),
-            }),
-        )
+        match err {
+            InferError::GenerationError(_) => (
+                StatusCode::FAILED_DEPENDENCY,
+                Json(ErrorResponse {
+                    error: err.to_string(),
+                }),
+            ),
+        }
     }
 }
diff --git a/router/src/db.rs b/router/src/db.rs
index f4302f80..51de9d05 100644
--- a/router/src/db.rs
+++ b/router/src/db.rs
@@ -1,8 +1,9 @@
-use crate::InferResponse;
 /// This code is massively inspired by Tokio mini-redis
+use crate::InferResponse;
 use crate::{GenerateParameters, GenerateRequest};
+use nohash_hasher::{BuildNoHashHasher, IntMap};
 use parking_lot::Mutex;
-use std::collections::{BTreeMap, HashMap};
+use std::collections::BTreeMap;
 use std::sync::Arc;
 use text_generation_client::{
     Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
@@ -117,7 +118,7 @@
         &self,
         min_size: Option<usize>,
         max_size: usize,
-    ) -> Option<(HashMap<u64, Entry>, Batch)> {
+    ) -> Option<(IntMap<u64, Entry>, Batch)> {
         // Acquire lock
         let mut state = self.shared.state.lock();
 
@@ -132,13 +133,13 @@
         // Batch size
         let size = requests.len();
 
-        let mut entries = HashMap::with_capacity(size);
+        let mut entries = IntMap::with_capacity_and_hasher(size, BuildNoHashHasher::default());
         ids.iter().for_each(|id| {
             // Remove entry from db
             let mut entry = state.entries.remove(id).unwrap();
             // Set batch_time
             entry.batch_time = Some(Instant::now());
-            // Insert in entries hashmap
+            // Insert in entries IntMap
             entries.insert(*id, entry);
         });
 
diff --git a/router/src/server.rs b/router/src/server.rs
index fc8f7848..623dd07c 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -7,9 +7,11 @@
 use axum::response::IntoResponse;
 use axum::routing::{get, post};
 use axum::{Json, Router};
 use std::net::SocketAddr;
+use std::sync::Arc;
 use text_generation_client::ShardedClient;
 use tokenizers::Tokenizer;
 use tokio::signal;
+use tokio::sync::Semaphore;
 use tokio::time::Instant;
 use tracing::instrument;
@@ -18,6 +20,7 @@
 struct ServerState {
     validation: Validation,
     batcher: Batcher,
+    limit_concurrent_requests: Arc<Semaphore>,
 }
 
 /// Health check method
@@ -27,6 +30,16 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
     req: Json<GenerateRequest>,
 ) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
     let start_time = Instant::now();
+    // Limit concurrent requests by acquiring a permit from the semaphore
+    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
+        tracing::error!("Model is overloaded");
+        (
+            StatusCode::TOO_MANY_REQUESTS,
+            Json(ErrorResponse {
+                error: "Model is overloaded".to_string(),
+            }),
+        )
+    })?;
 
     // Validate request
     let details = req.0.parameters.details;
@@ -162,16 +185,12 @@ pub async fn run(
     addr: SocketAddr,
 ) {
     // Create state
-    let batcher = Batcher::new(
-        client,
-        max_batch_size,
-        max_waiting_tokens,
-        max_concurrent_requests,
-    );
+    let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
     let validation = Validation::new(validation_workers, tokenizer, max_input_length);
     let shared_state = ServerState {
         validation,
         batcher,
+        limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
     };
 
     // Create router
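
Reviewer note (not part of the patch): the diff applies two independent techniques. First, the router's request ids are u64 values that are already unique, so running them through the default SipHash hasher is wasted work; nohash-hasher's IntMap is a std HashMap whose hasher uses the integer key itself as the hash. Second, admission control moves out of Batcher into the Axum handler: Semaphore::try_acquire fails immediately when no permit is free, so overload becomes a fast 429 response instead of a queued request. A minimal sketch of both primitives, assuming nohash-hasher 0.2 and tokio with the "sync" feature; the Entry struct here is an illustrative stand-in for the router's real one:

    use nohash_hasher::{BuildNoHashHasher, IntMap};
    use std::sync::Arc;
    use tokio::sync::Semaphore;

    // Stand-in for the router's Entry, which really holds the request,
    // the response channel, and timing data.
    struct Entry {
        input: String,
    }

    fn main() {
        // IntMap<u64, Entry> is HashMap<u64, Entry, BuildNoHashHasher<u64>>:
        // the u64 key is used directly as the hash value, skipping SipHash.
        let mut entries: IntMap<u64, Entry> =
            IntMap::with_capacity_and_hasher(8, BuildNoHashHasher::default());
        entries.insert(42, Entry { input: "hello".into() });
        assert!(entries.contains_key(&42));

        // try_acquire never waits: once all permits are taken it returns Err,
        // which the server maps to StatusCode::TOO_MANY_REQUESTS.
        let limit = Arc::new(Semaphore::new(1));
        let _permit = limit.try_acquire().expect("first permit is free");
        assert!(limit.try_acquire().is_err()); // a second caller would get 429
    }

Note that the handler binds the permit to _permit rather than discarding it with a bare _: binding keeps the permit alive, so it is released only when it goes out of scope at the end of the request, whereas _ would drop (and release) it immediately.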