diff --git a/Cargo.lock b/Cargo.lock
index 752c4886..33f5d181 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1087,6 +1087,12 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "nohash-hasher"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451"
+
 [[package]]
 name = "nom"
 version = "7.1.1"
@@ -1826,6 +1832,7 @@ dependencies = [
  "axum",
  "clap 4.0.22",
  "futures",
+ "nohash-hasher",
  "parking_lot",
  "serde",
  "serde_json",
diff --git a/router/Cargo.toml b/router/Cargo.toml
index f99069d3..546f127f 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -17,6 +17,7 @@ axum = { version = "0.5.16", features = ["json", "serde_json"] }
 text-generation-client = { path = "client" }
 clap = { version = "4.0.15", features = ["derive", "env"] }
 futures = "0.3.24"
+nohash-hasher = "0.2.0"
 parking_lot = "0.12.1"
 serde = "1.0.145"
 serde_json = "1.0.85"
diff --git a/router/src/batcher.rs b/router/src/batcher.rs
index ee83d899..df63a375 100644
--- a/router/src/batcher.rs
+++ b/router/src/batcher.rs
@@ -1,66 +1,70 @@
 /// Batching and inference logic
-use crate::{Db, Entry};
+use crate::Entry;
 use crate::{ErrorResponse, GenerateRequest};
 use axum::http::StatusCode;
 use axum::Json;
 use std::future::Future;
-use std::sync::Arc;
+use nohash_hasher::IntMap;
 use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use thiserror::Error;
-use tokio::sync::{oneshot, Notify};
+use tokio::sync::oneshot;
+use tokio::sync::mpsc::{channel, Permit, Sender};
+use tokio::sync::mpsc::error::TrySendError;
 use tokio::time::Instant;
 use tracing::instrument;
+use crate::queue::Queue;
 
 /// Batcher
 #[derive(Clone)]
 pub struct Batcher {
-    /// Request database
-    db: Db,
-    /// Shared state
-    shared: Arc<Shared>,
+    /// Request queue
+    sender: Sender<Entry>,
 }
 
-/// Batcher shared state
-struct Shared {
-    /// Batching background Tokio task notifier
-    batching_task: Notify,
-}
 
 impl Batcher {
     pub(crate) fn new(
         client: ShardedClient,
         max_batch_size: usize,
         max_waiting_tokens: usize,
+        queue_size: usize,
     ) -> Self {
-        // Batcher shared state
-        let db = Db::new();
-        let shared = Arc::new(Shared {
-            batching_task: Notify::new(),
-        });
+        // Set up queue
+        let (sender, receiver) = channel(queue_size);
 
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
            client,
            max_batch_size,
            max_waiting_tokens,
-            db.clone(),
-            shared.clone(),
+            Queue::new(receiver),
        ));
 
-        Self { db, shared }
+        Self { sender }
    }
 
-    /// Add a new request to the database and return a future that will generate the text
+    /// Reserve a slot in the queue for sending a request
+    pub(crate) fn reserve_slot(&self) -> Result<RequestSender<'_>, TrySendError<()>> {
+        self.sender.try_reserve().map(|permit| RequestSender { permit })
+    }
+}
+
+pub(crate) struct RequestSender<'a> {
+    permit: Permit<'a, Entry>
+}
+
+impl <'a> RequestSender<'a> {
+    /// Add a new request to the queue and return a future that will generate the text
    pub(crate) async fn infer(
-        &self,
+        self,
        input_length: usize,
        request: GenerateRequest,
    ) -> Result<InferResponse, InferError> {
        // One shot channel to communicate with the background batching task
        let (response_tx, response_rx) = oneshot::channel();
 
-        // Try to append the request to the database
-        self.db.append(Entry {
+        // Try to enqueue the request
+        self.permit.send(Entry {
            request,
            response_tx,
            input_length,
@@ -68,10 +72,6 @@ impl Batcher {
            batch_time: None,
        });
 
-        // Notify the background task that we have a new entry in the database that needs
-        // to be batched
-        self.shared.batching_task.notify_one();
-
        // Await on the response from the background task
        // We can safely unwrap as the background task will never drop the sender
        response_rx
@@ -85,68 +85,69 @@ impl Batcher {
 /// Will be launched in a background Tokio task
 ///
 /// Batches requests and sends them to the inference server
-#[instrument(skip(client, db, shared))]
+#[instrument(skip(client, queue))]
 async fn batching_task(
    mut client: ShardedClient,
    max_batch_size: usize,
    max_waiting_tokens: usize,
-    db: Db,
-    shared: Arc<Shared>,
+    mut queue: Queue,
 ) {
    // Minimum batch size after which we try to add more requests
    let limit_min_batch_size = (max_batch_size / 2) as u32;
 
-    // Infinite loop
-    loop {
-        // Wait for a notification from the Batcher struct
-        shared.batching_task.notified().await;
+    // Entries corresponding to all of the in-progress requests
+    let mut entries = IntMap::default();
 
-        // Get the next batch from the DB
-        // This batch might be smaller than the maximum batch size if there are not enough requests
-        // waiting in the DB
-        while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
-            let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
-            let mut waiting_tokens = 1;
+    // Get the next batch from the queue
+    // This batch might be smaller than the maximum batch size if there are not enough requests
+    // waiting in the queue
+    while let Some(batch) = queue.next_batch(max_batch_size, &mut entries).await {
+        let mut cached_batch = wrap_future(
+            client.generate(batch), None, &mut entries
+        ).await;
+        let mut waiting_tokens = 1;
 
-            // We loop until we do not receive any cached batch from the inference server (== until
-            // all requests have met their stopping criteria)
-            while let Some(batch) = cached_batch {
-                // Get current batch info
-                let batch_size = batch.size;
-                let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
-                let mut batches = vec![batch];
+        // We loop until we do not receive any cached batch from the inference server (== until
+        // all requests have met their stopping criteria)
+        while let Some(batch) = cached_batch {
+            // Get current batch info
+            let batch_size = batch.size;
+            let mut batches = vec![batch];
 
-                // If the current batch is too small, we try to add more requests to it
-                if batch_size <= limit_min_batch_size {
-                    let min_size = match waiting_tokens {
-                        // If we didn't onboard any new requests since >= max_waiting_tokens, we try
-                        // to add a new batch even though its size might be small
-                        _ if waiting_tokens >= max_waiting_tokens => None,
-                        // Minimum size criteria
-                        _ => Some(limit_min_batch_size as usize),
-                    };
+            // If the current batch is too small, we try to add more requests to it
+            if batch_size <= limit_min_batch_size {
+                let min_size = match waiting_tokens {
+                    // If we didn't onboard any new requests since >= max_waiting_tokens, we try
+                    // to add a new batch even though its size might be small
+                    _ if waiting_tokens >= max_waiting_tokens => 1,
+                    // Minimum size criteria
+                    _ => limit_min_batch_size as usize,
+                };
 
-                    // Try to get a new batch
-                    if let Some((new_request_ids, new_batch)) =
-                        db.next_batch(min_size, max_batch_size - batch_size as usize)
-                    {
-                        // Generate one token for this new batch to have the attention past in cache
-                        let new_cached_batch =
-                            wrap_future(client.generate(new_batch), new_request_ids, &db).await;
-                        // Reset waiting counter
-                        waiting_tokens = 1;
-                        // Extend current batch with the new batch
-                        if let Some(new_cached_batch) = new_cached_batch {
-                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
-                            batches.push(new_cached_batch);
-                        }
+                // Try to get a new batch
+                if let Some(new_batch) = queue.try_next_batch(
+                    min_size, max_batch_size - batch_size as usize, &mut entries
+                ) {
+                    let first_new_id = new_batch.requests.first()
+                        .expect("batch can't be empty here").id;
+                    // Generate one token for this new batch to have the attention past in cache
+                    let new_cached_batch = wrap_future(
+                        client.generate(new_batch), Some(first_new_id), &mut entries
+                    ).await;
+
+                    // Reset waiting counter
+                    waiting_tokens = 1;
+                    // Extend current batch with the new batch
+                    if let Some(new_cached_batch) = new_cached_batch {
+                        batches.push(new_cached_batch);
                    }
                }
-
-                cached_batch =
-                    wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
-                waiting_tokens += 1;
            }
+
+            cached_batch = wrap_future(
+                client.generate_with_cache(batches), None, &mut entries
+            ).await;
+            waiting_tokens += 1;
        }
    }
 }
@@ -154,39 +155,45 @@ async fn batching_task(
 /// Wrap a future inside a match statement to handle errors and send the response to the Batcher
 async fn wrap_future(
    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
-    request_ids: Vec<u64>,
-    db: &Db,
+    // First request id in this batch if it doesn't comprise all current entries
+    start_id: Option<u64>,
+    entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
    match future.await {
        Ok((generated_texts, next_batch)) => {
-            send_generated(generated_texts, db);
+            send_generated(generated_texts, entries);
            next_batch
        }
        // If we have an error, we discard the whole batch
        Err(err) => {
-            send_error(err, request_ids, db);
+            send_error(err, start_id, entries);
            None
        }
    }
 }
 
-/// Send errors to the Batcher for all `request_ids`
-fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
-    request_ids.into_iter().for_each(|id| {
-        // We can `expect` here as the request id should always be in the DB
-        let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
-        // unwrap_or is valid here as we don't care if the receiver is gone.
-        entry.response_tx.send(Err(error.clone())).unwrap_or(());
-    });
+/// Send errors to the Batcher for all failed entries
+fn send_error(error: ClientError, start_id: Option<u64>, entries: &mut IntMap<u64, Entry>) {
+    let to_keep = entries.drain().filter_map(|(id, entry)| match start_id {
+        // Keep entries that weren't in the failed request batch
+        Some(sid) if id < sid => Some((id, entry)),
+        _ => {
+            // unwrap_or is valid here as we don't care if the receiver is gone.
+            entry.response_tx.send(Err(error.clone())).unwrap_or(());
+            None
+        }
+    }).collect::<IntMap<u64, Entry>>();
+    // Workaround since drain_filter() is not yet stable. This will be empty when start_id == None.
+    entries.extend(to_keep);
 }
 
 /// Send `generated_text` to the Batcher for all `finished`
-fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
+fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>) {
    finished.into_iter().for_each(|output| {
-        // We can `expect` here as the request id should always be in the DB
-        let entry = db
+        // We can `expect` here as the request id should always be in the map
+        let entry = entries
            .remove(&output.request.unwrap().id)
-            .expect("ID not found in db. This is a bug.");
+            .expect("ID not found. This is a bug.");
 
        let response = InferResponse {
            output_text: output.output_text,
diff --git a/router/src/db.rs b/router/src/db.rs
deleted file mode 100644
index 1d7df627..00000000
--- a/router/src/db.rs
+++ /dev/null
@@ -1,179 +0,0 @@
-use crate::InferResponse;
-/// This code is massively inspired by Tokio mini-redis
-use crate::{GenerateParameters, GenerateRequest};
-use parking_lot::Mutex;
-use std::collections::BTreeMap;
-use std::sync::Arc;
-use text_generation_client::{
-    Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
-};
-use tokio::sync::oneshot::Sender;
-use tokio::time::Instant;
-
-/// Database entry
-#[derive(Debug)]
-pub(crate) struct Entry {
-    /// Request
-    pub request: GenerateRequest,
-    /// Response sender to communicate between the Batcher and the batching_task
-    pub response_tx: Sender<Result<InferResponse, ClientError>>,
-    /// Number of tokens in the input
-    pub input_length: usize,
-    /// Instant when this entry was created
-    pub time: Instant,
-    /// Instant when this entry was added to a batch
-    pub batch_time: Option<Instant>,
-}
-
-/// Request Database
-#[derive(Debug, Clone)]
-pub(crate) struct Db {
-    pub shared: Arc<Shared>,
-}
-
-/// Shared state
-#[derive(Debug)]
-pub struct Shared {
-    state: Mutex<State>,
-}
-
-/// Database State
-#[derive(Debug)]
-struct State {
-    /// Database entries organized in a BTreeMap to be able to iterate over them in order
-    entries: BTreeMap<u64, Entry>,
-
-    /// Id of the next entry
-    next_id: u64,
-
-    /// Id of the next batch
-    next_batch_id: u64,
-
-    /// Start ID of the next batch. Used to iterate inside the entries BTreeMap
-    next_batch_start_id: u64,
-}
-
-impl State {
-    /// Get the next requests
-    fn next_requests(&self, max_size: usize) -> Option<(Vec<u64>, Vec<Request>)> {
-        // Iterates for max_size over the BTreemap starting from next_batch_start_id
-        let mut requests = Vec::new();
-        let mut ids = Vec::new();
-
-        for (id, entry) in self
-            .entries
-            // Start from next_batch_start_id
-            .range(self.next_batch_start_id..)
-            // Take max_size
-            .take(max_size)
-        {
-            requests.push(Request {
-                id: *id,
-                inputs: entry.request.inputs.clone(),
-                input_length: entry.input_length as u32,
-                parameters: Some((&entry.request.parameters).into()),
-                stopping_parameters: Some(entry.request.parameters.clone().into()),
-            });
-
-            ids.push(*id);
-        }
-
-        if requests.is_empty() {
-            None
-        } else {
-            Some((ids, requests))
-        }
-    }
-}
-
-impl Db {
-    pub(crate) fn new() -> Self {
-        // Shared state
-        let shared = Arc::new(Shared {
-            state: Mutex::new(State {
-                entries: BTreeMap::new(),
-                next_id: 0,
-                next_batch_id: 0,
-                next_batch_start_id: 0,
-            }),
-        });
-
-        Self { shared }
-    }
-
-    /// Append an entry to the database
-    pub(crate) fn append(&self, entry: Entry) {
-        // Acquire lock
-        let mut state = self.shared.state.lock();
-
-        // Insert entry
-        let id = state.next_id;
-        state.next_id += 1;
-        state.entries.insert(id, entry);
-    }
-
-    /// Remove an entry from the database if it exists
-    pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
-        let mut state = self.shared.state.lock();
-        state.entries.remove(id)
-    }
-
-    // Get the next batch
-    pub(crate) fn next_batch(
-        &self,
-        min_size: Option<usize>,
-        max_size: usize,
-    ) -> Option<(Vec<u64>, Batch)> {
-        // Acquire lock
-        let mut state = self.shared.state.lock();
-
-        // Get requests from the database
-        if let Some((ids, requests)) = state.next_requests(max_size) {
-            if let Some(min_size) = min_size {
-                // If min_size is set, only return a batch if there are enough requests
-                if requests.len() < min_size {
-                    return None;
-                }
-            }
-            ids.iter().for_each(|id| {
-                // Set batch_time for each request
-                state.entries.get_mut(id).unwrap().batch_time = Some(Instant::now());
-            });
-
-            // Batch size
-            let size = requests.len();
-            let batch = Batch {
-                id: state.next_batch_id,
-                requests,
-                size: size as u32,
-            };
-            // Update next_batch_start_id to the last id in the batch + 1
-            state.next_batch_start_id = ids.last().unwrap() + 1;
-            // Increment batch id
-            state.next_batch_id += 1;
-
-            return Some((ids, batch));
-        }
-        None
-    }
-}
-
-impl From<&GenerateParameters> for NextTokenChooserParameters {
-    fn from(parameters: &GenerateParameters) -> Self {
-        Self {
-            temperature: parameters.temperature,
-            top_k: parameters.top_k as u32,
-            top_p: parameters.top_p,
-            do_sample: parameters.do_sample,
-        }
-    }
-}
-
-impl From<GenerateParameters> for StoppingCriteriaParameters {
-    fn from(parameters: GenerateParameters) -> Self {
-        Self {
-            stop_sequences: parameters.stop,
-            max_new_tokens: parameters.max_new_tokens,
-        }
-    }
-}
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 03711580..d4edcab3 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1,11 +1,11 @@
 /// Text Generation Inference Webserver
 mod batcher;
-mod db;
+mod queue;
 pub mod server;
 mod validation;
 
 use batcher::{Batcher, InferResponse};
-use db::{Db, Entry};
+use queue::Entry;
 use serde::{Deserialize, Serialize};
 use validation::Validation;
 
diff --git a/router/src/queue.rs b/router/src/queue.rs
new file mode 100644
index 00000000..0e323799
--- /dev/null
+++ b/router/src/queue.rs
@@ -0,0 +1,137 @@
+use std::cmp::min;
+use crate::InferResponse;
+use crate::{GenerateParameters, GenerateRequest};
+use std::collections::VecDeque;
+use nohash_hasher::IntMap;
+use tokio::sync::mpsc::Receiver;
+use text_generation_client::{
+    Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
+};
+use tokio::sync::oneshot::Sender;
+use tokio::time::Instant;
+
+/// In-flight request record
+#[derive(Debug)]
+pub(crate) struct Entry {
+    /// Request
+    pub request: GenerateRequest,
+    /// Response sender to communicate between the Batcher and the batching_task
+    pub response_tx: Sender<Result<InferResponse, ClientError>>,
+    /// Number of tokens in the input
+    pub input_length: usize,
+    /// Instant when this entry was created
+    pub time: Instant,
+    /// Instant when this entry was added to a batch
+    pub batch_time: Option<Instant>,
+}
+
+/// Request Queue
+#[derive(Debug)]
+pub(crate) struct Queue {
+    receiver: Receiver<Entry>,
+    buffer: VecDeque<Entry>,
+    /// Id of the next entry
+    next_id: u64,
+    /// Id of the next batch
+    next_batch_id: u64,
+}
+
+
+impl Queue {
+    pub(crate) fn new(receiver: Receiver<Entry>) -> Self {
+        Self { receiver, buffer: VecDeque::new(), next_id: 0, next_batch_id: 0 }
+    }
+
+    /// Get the next batch, blocking until available
+    /// Corresponding entries are added to the entries map
+    pub(crate) async fn next_batch(
+        &mut self,
+        max_size: usize,
+        entries: &mut IntMap<u64, Entry>,
+    ) -> Option<Batch> {
+        loop {
+            if self.buffer.is_empty() {
+                match self.receiver.recv().await {
+                    Some(ent) => self.buffer.push_back(ent),
+                    None => return None,
+                }
+            }
+            if let Some(batch) = self.try_next_batch(1, max_size, entries) {
+                return Some(batch)
+            }
+        }
+    }
+
+    /// Get the next batch without blocking
+    /// Corresponding entries are added to the entries map
+    pub(crate) fn try_next_batch(
+        &mut self,
+        min_size: usize,
+        max_size: usize,
+        entries: &mut IntMap<u64, Entry>,
+    ) -> Option<Batch> {
+        while self.buffer.len() < max_size {
+            match self.receiver.try_recv() {
+                Ok(ent) => self.buffer.push_back(ent),
+                _ => break,
+            }
+        }
+
+        let len = self.buffer.len();
+        if len < min_size || len == 0 {
+            // Can't get minimum
+            return None;
+        }
+
+        let now = Some(Instant::now());
+        let requests = self.buffer.drain(..min(len, max_size))
+            .map(|mut entry| {
+                let id = self.next_id;
+                self.next_id += 1;
+                let request = Request {
+                    id,
+                    inputs: entry.request.inputs.clone(),
+                    input_length: entry.input_length as u32,
+                    parameters: Some((&entry.request.parameters).into()),
+                    stopping_parameters: Some(entry.request.parameters.clone().into()),
+                };
+                entry.batch_time = now;
+                entries.insert(id, entry);
+                request
+            })
+            .collect::<Vec<Request>>();
+
+        // Batch size
+        let size = requests.len();
+        let batch = Batch {
+            id: self.next_batch_id,
+            requests,
+            size: size as u32,
+        };
+        // Increment batch id
+        self.next_batch_id += 1;
+
+        Some(batch)
+    }
+}
+
+
+impl From<&GenerateParameters> for NextTokenChooserParameters {
+    fn from(parameters: &GenerateParameters) -> Self {
+        Self {
+            temperature: parameters.temperature,
+            top_k: parameters.top_k as u32,
+            top_p: parameters.top_p,
+            do_sample: parameters.do_sample,
+        }
+    }
+}
+
+impl From<GenerateParameters> for StoppingCriteriaParameters {
+    fn from(parameters: GenerateParameters) -> Self {
+        Self {
+            stop_sequences: parameters.stop,
+            max_new_tokens: parameters.max_new_tokens,
+        }
+    }
+}
diff --git a/router/src/server.rs b/router/src/server.rs
index 2e6c473f..7607ac43 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -7,11 +7,9 @@ use axum::response::IntoResponse;
 use axum::routing::{get, post};
 use axum::{Json, Router};
 use std::net::SocketAddr;
-use std::sync::Arc;
 use text_generation_client::ShardedClient;
 use tokenizers::Tokenizer;
 use tokio::signal;
-use tokio::sync::Semaphore;
 use tokio::time::Instant;
 use tracing::instrument;
 
@@ -20,7 +18,6 @@ use tracing::instrument;
 struct ServerState {
    validation: Validation,
    batcher: Batcher,
-    limit_concurrent_requests: Arc<Semaphore>,
 }
 
 /// Health check method
@@ -30,8 +27,8 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<
    req: Json<GenerateRequest>,
 ) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
    let start_time = Instant::now();
-    // Limit concurrent requests by acquiring a permit from the semaphore
-    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
+    // Limit concurrent requests by reserving a slot in the queue
+    let sender = state.batcher.reserve_slot().map_err(|_| {
        tracing::error!("Model is overloaded");
        (
            StatusCode::TOO_MANY_REQUESTS,
@@ -98,8 +93,7 @@ async fn generate(
    })?;
 
    // Inference
-    let response = state
-        .batcher
+    let response = sender
        .infer(input_length, validated_request)
        .await
        .map_err(|err| {
@@ -185,12 +179,13 @@ pub async fn run(
    addr: SocketAddr,
 ) {
    // Create state
-    let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
+    let batcher = Batcher::new(
+        client, max_batch_size, max_waiting_tokens, max_concurrent_requests
+    );
    let validation = Validation::new(validation_workers, tokenizer, max_input_length);
    let shared_state = ServerState {
        validation,
        batcher,
-        limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
    };
 
    // Create router
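Not part of the diff: below is a minimal standalone sketch (Rust, tokio only, hypothetical names such as `Job`) of the admission-control pattern this change adopts, where a bounded `mpsc` channel plus `try_reserve` replaces the old `Semaphore` so a full queue rejects work before anything is enqueued.

```rust
use tokio::sync::mpsc;

#[derive(Debug)]
struct Job(u32); // stand-in for the router's `Entry`

// current_thread runtime so the consumer only runs at await points,
// which makes the rejection behaviour below deterministic
#[tokio::main(flavor = "current_thread")]
async fn main() {
    // Bounded channel: the capacity plays the role of `max_concurrent_requests`
    let (tx, mut rx) = mpsc::channel::<Job>(2);

    // Background consumer, analogous to `batching_task` draining the `Queue`
    let worker = tokio::spawn(async move {
        while let Some(job) = rx.recv().await {
            println!("processing {job:?}");
        }
    });

    for i in 0..4 {
        // `try_reserve` fails immediately when the queue is full; the server
        // maps that failure to a 429 "Model is overloaded" response
        match tx.try_reserve() {
            Ok(permit) => permit.send(Job(i)),
            Err(_) => println!("rejecting {i}: queue full"),
        }
    }

    drop(tx); // close the channel so the worker loop ends
    worker.await.unwrap();
}
```

The point of the pattern, as in the diff, is that the capacity check and the enqueue are one atomic step: holding the `Permit` guarantees the later `send` cannot fail, so no separate semaphore or notification is needed.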