From 7b870e1e1837654bfc268748138ea81402c6afc7 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Thu, 2 Feb 2023 14:59:27 +0100
Subject: [PATCH] feat(router): use background task to manage request queue
 (#52)

Co-authored-by: Nick Hill
---
 launcher/src/main.rs     |   5 +-
 router/src/db.rs         | 160 ------------------
 router/src/infer.rs      |  37 ++--
 router/src/lib.rs        |   6 +-
 router/src/queue.rs      | 355 +++++++++++++++++++++++++++++++++++++++
 router/src/validation.rs |   2 +-
 6 files changed, 382 insertions(+), 183 deletions(-)
 delete mode 100644 router/src/db.rs
 create mode 100644 router/src/queue.rs

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 20ec7faa..dea6fcc8 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -316,7 +316,10 @@ fn shard_manager(
     // If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard
     // Useful when running inside a HuggingFace Inference Endpoint
     if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") {
-        env.push(("WEIGHTS_CACHE_OVERRIDE".into(), weights_cache_override.into()));
+        env.push((
+            "WEIGHTS_CACHE_OVERRIDE".into(),
+            weights_cache_override.into(),
+        ));
     };

     // If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
diff --git a/router/src/db.rs b/router/src/db.rs
deleted file mode 100644
index 246e4d5d..00000000
--- a/router/src/db.rs
+++ /dev/null
@@ -1,160 +0,0 @@
-/// This code is massively inspired by Tokio mini-redis
-use crate::infer::InferError;
-use crate::infer::InferStreamResponse;
-use crate::validation::ValidGenerateRequest;
-use nohash_hasher::{BuildNoHashHasher, IntMap};
-use parking_lot::Mutex;
-use std::collections::BTreeMap;
-use std::sync::Arc;
-use text_generation_client::{Batch, Request};
-use tokio::sync::mpsc::UnboundedSender;
-use tokio::sync::OwnedSemaphorePermit;
-use tokio::time::Instant;
-
-/// Database entry
-#[derive(Debug)]
-pub(crate) struct Entry {
-    /// Request
-    pub request: ValidGenerateRequest,
-    /// Response sender to communicate between the Infer struct and the batching_task
-    pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
-    /// Instant when this entry was created
-    pub time: Instant,
-    /// Instant when this entry was added to a batch
-    pub batch_time: Option<Instant>,
-    /// Permit
-    pub _permit: OwnedSemaphorePermit,
-}
-
-/// Request Database
-#[derive(Debug, Clone)]
-pub(crate) struct Db {
-    pub shared: Arc<Shared>,
-}
-
-/// Shared state
-#[derive(Debug)]
-pub struct Shared {
-    state: Mutex<State>,
-}
-
-/// Database State
-#[derive(Debug)]
-struct State {
-    /// Database entries organized in a BTreeMap to be able to iterate over them in order
-    entries: BTreeMap<u64, Entry>,
-
-    /// Id of the next entry
-    next_id: u64,
-
-    /// Id of the next batch
-    next_batch_id: u64,
-
-    /// Start ID of the next batch. Used to iterate inside the entries BTreeMap
-    next_batch_start_id: u64,
-}
-
-impl State {
-    /// Get the next requests
-    fn next_requests(&self, max_size: usize) -> Option<(Vec<u64>, Vec<Request>)> {
-        // Iterates for max_size over the BTreemap starting from next_batch_start_id
-        let mut requests = Vec::new();
-        let mut ids = Vec::new();
-
-        for (id, entry) in self
-            .entries
-            // Start from next_batch_start_id
-            .range(self.next_batch_start_id..)
-            // Take max_size
-            .take(max_size)
-        {
-            requests.push(Request {
-                id: *id,
-                inputs: entry.request.inputs.clone(),
-                input_length: entry.request.input_length,
-                parameters: Some(entry.request.parameters.clone()),
-                stopping_parameters: Some(entry.request.stopping_parameters.clone()),
-            });
-
-            ids.push(*id);
-        }
-
-        if requests.is_empty() {
-            None
-        } else {
-            Some((ids, requests))
-        }
-    }
-}
-
-impl Db {
-    pub(crate) fn new() -> Self {
-        // Shared state
-        let shared = Arc::new(Shared {
-            state: Mutex::new(State {
-                entries: BTreeMap::new(),
-                next_id: 0,
-                next_batch_id: 0,
-                next_batch_start_id: 0,
-            }),
-        });
-
-        Self { shared }
-    }
-
-    /// Append an entry to the database
-    pub(crate) fn append(&self, entry: Entry) {
-        // Acquire lock
-        let mut state = self.shared.state.lock();
-
-        // Insert entry
-        let id = state.next_id;
-        state.next_id += 1;
-        state.entries.insert(id, entry);
-    }
-
-    // Get the next batch
-    pub(crate) fn next_batch(
-        &self,
-        min_size: Option<usize>,
-        max_size: usize,
-    ) -> Option<(IntMap<u64, Entry>, Batch)> {
-        // Acquire lock
-        let mut state = self.shared.state.lock();
-
-        // Get requests from the database
-        if let Some((ids, requests)) = state.next_requests(max_size) {
-            if let Some(min_size) = min_size {
-                // If min_size is set, only return a batch if there are enough requests
-                if requests.len() < min_size {
-                    return None;
-                }
-            }
-            // Batch size
-            let size = requests.len();
-
-            let mut entries = IntMap::with_capacity_and_hasher(size, BuildNoHashHasher::default());
-            ids.iter().for_each(|id| {
-                // Remove entry from db
-                let mut entry = state.entries.remove(id).unwrap();
-                // Set batch_time
-                entry.batch_time = Some(Instant::now());
-                // Insert in entries IntMap
-                entries.insert(*id, entry);
-            });
-
-            let batch = Batch {
-                id: state.next_batch_id,
-                requests,
-                size: size as u32,
-            };
-            // Update next_batch_start_id to the last id in the batch + 1
-            state.next_batch_start_id = ids.last().unwrap() + 1;
-            // Increment batch id
-            state.next_batch_id += 1;
-
-            return Some((entries, batch));
-        }
-        None
-    }
-}
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 23e84265..3661b0e0 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -1,7 +1,7 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
 use crate::GenerateRequest;
-use crate::{Db, Entry, Token};
+use crate::{Entry, Queue, Token};
 use nohash_hasher::IntMap;
 use std::future::Future;
 use std::sync::Arc;
@@ -20,8 +20,8 @@ pub struct Infer {
     /// Validation
     validation: Validation,
-    /// Request database
-    db: Db,
+    /// Request queue
+    queue: Queue,
     /// Shared state
     shared: Arc<Shared>,
     /// Inference limit
@@ -43,7 +43,7 @@ impl Infer {
         max_concurrent_requests: usize,
     ) -> Self {
         // Infer shared state
-        let db = Db::new();
+        let queue = Queue::new();
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
@@ -53,7 +53,7 @@
             client,
             max_batch_size,
             max_waiting_tokens,
-            db.clone(),
+            queue.clone(),
             shared.clone(),
         ));
@@ -62,13 +62,13 @@

         Self {
             validation,
-            db,
+            queue,
             shared,
             limit_concurrent_requests: semaphore,
         }
     }

-    /// Add a new request to the database and return a stream of InferStreamResponse
+    /// Add a new request to the queue and return a stream of InferStreamResponse
     pub(crate) async fn generate_stream(
         &self,
         request: GenerateRequest,
@@ -83,8 +83,8 @@
         // MPSC channel to communicate with the background batching task
         let (response_tx, response_rx) = mpsc::unbounded_channel();

-        // Append the request to the database
-        self.db.append(Entry {
+        // Append the request to the queue
+        self.queue.append(Entry {
             request: valid_request,
             response_tx,
             time: Instant::now(),
@@ -92,7 +92,7 @@
             _permit: permit,
         });

-        // Notify the background task that we have a new entry in the database that needs
+        // Notify the background task that we have a new entry in the queue that needs
         // to be batched
         self.shared.batching_task.notify_one();

@@ -100,7 +100,7 @@
         Ok(UnboundedReceiverStream::new(response_rx))
     }

-    /// Add a new request to the database and return a InferResponse
+    /// Add a new request to the queue and return an InferResponse
     pub(crate) async fn generate(
         &self,
         request: GenerateRequest,
@@ -169,12 +169,12 @@
 /// Will be launched in a background Tokio task
 ///
 /// Batches requests and sends them to the inference server
-#[instrument(skip(client, db, shared))]
+#[instrument(skip(client, queue, shared))]
 async fn batching_task(
     mut client: ShardedClient,
     max_batch_size: usize,
     max_waiting_tokens: usize,
-    db: Db,
+    queue: Queue,
     shared: Arc<Shared>,
 ) {
     // Minimum batch size after which we try to add more requests
@@ -185,10 +185,10 @@
         // Wait for a notification from the Infer struct
         shared.batching_task.notified().await;

-        // Get the next batch from the DB
+        // Get the next batch from the queue
         // This batch might be smaller than the maximum batch size if there are not enough requests
-        // waiting in the DB
-        while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) {
+        // waiting in the queue
+        while let Some((mut entries, batch)) = queue.next_batch(None, max_batch_size).await {
             let mut cached_batch = wrap_future(client.prefill(batch), &mut entries).await;
             let mut waiting_tokens = 1;
@@ -210,8 +210,9 @@
                 };

                 // Try to get a new batch
-                if let Some((mut new_entries, new_batch)) =
-                    db.next_batch(min_size, max_batch_size - batch_size as usize)
+                if let Some((mut new_entries, new_batch)) = queue
+                    .next_batch(min_size, max_batch_size - batch_size as usize)
+                    .await
                 {
                     // Generate one token for this new batch to have the attention past in cache
                     let new_cached_batch =
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 5b96485f..c6ac2022 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1,11 +1,11 @@
-/// Text Generation Inference Webserver
-mod db;
 mod infer;
+/// Text Generation Inference Webserver
+mod queue;
 pub mod server;
 mod validation;

-use db::{Db, Entry};
 use infer::Infer;
+use queue::{Entry, Queue};
 use serde::{Deserialize, Serialize};
 use validation::Validation;
diff --git a/router/src/queue.rs b/router/src/queue.rs
new file mode 100644
index 00000000..2aaf93b1
--- /dev/null
+++ b/router/src/queue.rs
@@ -0,0 +1,355 @@
+use crate::infer::InferError;
+use crate::infer::InferStreamResponse;
+use crate::validation::ValidGenerateRequest;
+use nohash_hasher::{BuildNoHashHasher, IntMap};
+use std::cmp::min;
+use text_generation_client::{Batch, Request};
+use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
+use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
+use tokio::time::Instant;
+
+/// Queue entry
+#[derive(Debug)]
+pub(crate) struct Entry {
+    /// Request
+    pub request: ValidGenerateRequest,
+    /// Response sender to communicate between the Infer struct and the batching_task
+    pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
+    /// Instant when this entry was created
+    pub time: Instant,
+    /// Instant when this entry was added to a batch
+    pub batch_time: Option<Instant>,
+    /// Permit
+    pub _permit: OwnedSemaphorePermit,
+}
+
+/// Request Queue
+#[derive(Debug, Clone)]
+pub(crate) struct Queue {
+    /// Channel to communicate with the background queue task
+    queue_sender: UnboundedSender<QueueCommand>,
+}
+
+impl Queue {
+    pub(crate) fn new() -> Self {
+        // Create channel
+        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
+
+        // Launch background queue task
+        tokio::spawn(queue_task(queue_receiver));
+
+        Self { queue_sender }
+    }
+
+    /// Append an entry to the queue
+    pub(crate) fn append(&self, entry: Entry) {
+        // Send append command to the background task managing the state
+        // Unwrap is safe here
+        self.queue_sender.send(QueueCommand::Append(entry)).unwrap();
+    }
+
+    // Get the next batch
+    pub(crate) async fn next_batch(
+        &self,
+        min_size: Option<usize>,
+        max_size: usize,
+    ) -> Option<NextBatch> {
+        // Create response channel
+        let (response_sender, response_receiver) = oneshot::channel();
+        // Send next batch command to the background task managing the state
+        // Unwrap is safe here
+        self.queue_sender
+            .send(QueueCommand::NextBatch {
+                min_size,
+                max_size,
+                response_sender,
+            })
+            .unwrap();
+        // Await on response channel
+        // Unwrap is safe here
+        response_receiver.await.unwrap()
+    }
+}
+
+// Background task responsible for the queue state
+async fn queue_task(mut receiver: UnboundedReceiver<QueueCommand>) {
+    let mut state = State::new();
+
+    while let Some(cmd) = receiver.recv().await {
+        match cmd {
+            QueueCommand::Append(entry) => state.append(entry),
+            QueueCommand::NextBatch {
+                min_size,
+                max_size,
+                response_sender,
+            } => {
+                let next_batch = state.next_batch(min_size, max_size);
+                response_sender.send(next_batch).unwrap_or(());
+            }
+        }
+    }
+}
+
+/// Queue State
+#[derive(Debug)]
+struct State {
+    /// Queue entries organized in a Vec
+    entries: Vec<(u64, Entry)>,
+
+    /// Id of the next entry
+    next_id: u64,
+
+    /// Id of the next batch
+    next_batch_id: u64,
+}
+
+impl State {
+    fn new() -> Self {
+        Self {
+            entries: Vec::with_capacity(128),
+            next_id: 0,
+            next_batch_id: 0,
+        }
+    }
+
+    /// Append an entry to the queue
+    fn append(&mut self, entry: Entry) {
+        self.entries.push((self.next_id, entry));
+        self.next_id += 1;
+    }
+
+    // Get the next batch
+    fn next_batch(&mut self, min_size: Option<usize>, max_size: usize) -> Option<NextBatch> {
+        if self.entries.is_empty() {
+            return None;
+        }
+
+        // Check if we have enough entries
+        if let Some(min_size) = min_size {
+            if self.entries.len() < min_size {
+                return None;
+            }
+        }
+
+        let next_batch_size = min(self.entries.len(), max_size);
+
+        let mut batch_requests = Vec::with_capacity(next_batch_size);
+        let mut batch_entries =
+            IntMap::with_capacity_and_hasher(next_batch_size, BuildNoHashHasher::default());
+
+        // Drain next_batch_size entries
+        self.entries
+            .drain(..next_batch_size)
+            .for_each(|(id, mut entry)| {
+                batch_requests.push(Request {
+                    id,
+                    inputs: entry.request.inputs.clone(),
+                    input_length: entry.request.input_length,
+                    parameters: Some(entry.request.parameters.clone()),
+                    stopping_parameters: Some(entry.request.stopping_parameters.clone()),
+                });
+                // Set batch_time
+                entry.batch_time = Some(Instant::now());
+                // Insert in batch_entries IntMap
+                batch_entries.insert(id, entry);
+            });
+
+        let batch = Batch {
+            id: self.next_batch_id,
+            requests: batch_requests,
+            size: next_batch_size as u32,
+        };
+        // Increment batch id
+        self.next_batch_id += 1;
+
+        Some((batch_entries, batch))
+    }
+}
+
+type NextBatch = (IntMap<u64, Entry>, Batch);
+
+#[derive(Debug)]
+enum QueueCommand {
+    Append(Entry),
+    NextBatch {
+        min_size: Option<usize>,
+        max_size: usize,
+        response_sender: oneshot::Sender<Option<NextBatch>>,
+    },
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::sync::Arc;
+    use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
+    use tokio::sync::{mpsc, Semaphore};
+
+    fn default_entry() -> Entry {
+        let semaphore = Arc::new(Semaphore::new(1));
+        let (response_tx, _) = mpsc::unbounded_channel();
+        let permit = semaphore.try_acquire_owned().unwrap();
+
+        Entry {
+            request: ValidGenerateRequest {
+                inputs: "".to_string(),
+                input_length: 0,
+                parameters: NextTokenChooserParameters {
+                    temperature: 0.0,
+                    top_k: 0,
+                    top_p: 0.0,
+                    do_sample: false,
+                    seed: 0,
+                    repetition_penalty: 0.0,
+                },
+                stopping_parameters: StoppingCriteriaParameters {
+                    max_new_tokens: 0,
+                    stop_sequences: vec![],
+                },
+            },
+            response_tx,
+            time: Instant::now(),
+            batch_time: None,
+            _permit: permit,
+        }
+    }
+
+    #[test]
+    fn test_append() {
+        let mut state = State::new();
+        let entry = default_entry();
+
+        assert_eq!(state.next_id, 0);
+        assert_eq!(state.entries.len(), 0);
+
+        state.append(entry);
+
+        assert_eq!(state.next_id, 1);
+        assert_eq!(state.entries.len(), 1);
+        let (id, _) = state.entries.remove(0);
+        assert_eq!(id, 0);
+    }
+
+    #[test]
+    fn test_next_batch_empty() {
+        let mut state = State::new();
+
+        assert!(state.next_batch(None, 1).is_none());
+        assert!(state.next_batch(Some(1), 1).is_none());
+    }
+
+    #[test]
+    fn test_next_batch_min_size() {
+        let mut state = State::new();
+        state.append(default_entry());
+        state.append(default_entry());
+
+        let (entries, batch) = state.next_batch(None, 2).unwrap();
+        assert_eq!(entries.len(), 2);
+        assert!(entries.contains_key(&0));
+        assert!(entries.contains_key(&1));
+        assert!(entries.get(&0).unwrap().batch_time.is_some());
+        assert!(entries.get(&1).unwrap().batch_time.is_some());
+        assert_eq!(batch.id, 0);
+        assert_eq!(batch.size, 2);
+
+        assert_eq!(state.next_id, 2);
+        assert_eq!(state.entries.len(), 0);
+        assert_eq!(state.next_batch_id, 1);
+
+        state.append(default_entry());
+
+        assert!(state.next_batch(Some(2), 2).is_none());
+
+        assert_eq!(state.next_id, 3);
+        assert_eq!(state.entries.len(), 1);
+        let (id, _) = state.entries.remove(0);
+        assert_eq!(id, 2);
+    }
+
+    #[test]
+    fn test_next_batch_max_size() {
+        let mut state = State::new();
+        state.append(default_entry());
+        state.append(default_entry());
+
+        let (entries, batch) = state.next_batch(None, 1).unwrap();
+        assert_eq!(entries.len(), 1);
+        assert!(entries.contains_key(&0));
+        assert_eq!(batch.id, 0);
+        assert_eq!(batch.size, 1);
+
+        assert_eq!(state.next_id, 2);
+        assert_eq!(state.entries.len(), 1);
+        assert_eq!(state.next_batch_id, 1);
+
+        state.append(default_entry());
+
+        let (entries, batch) = state.next_batch(None, 3).unwrap();
+        assert_eq!(entries.len(), 2);
+        assert!(entries.contains_key(&1));
+        assert!(entries.contains_key(&2));
+        assert_eq!(batch.id, 1);
+        assert_eq!(batch.size, 2);
+
+        assert_eq!(state.next_id, 3);
+        assert_eq!(state.entries.len(), 0);
+        assert_eq!(state.next_batch_id, 2);
+    }
+
+    #[tokio::test]
+    async fn test_queue_append() {
+        let queue = Queue::new();
+        queue.append(default_entry());
+    }
+
+    #[tokio::test]
+    async fn test_queue_next_batch_empty() {
+        let queue = Queue::new();
+
+        assert!(queue.next_batch(None, 1).await.is_none());
+        assert!(queue.next_batch(Some(1), 1).await.is_none());
+    }
+
+    #[tokio::test]
+    async fn test_queue_next_batch_min_size() {
+        let queue = Queue::new();
+        queue.append(default_entry());
+        queue.append(default_entry());
+
+        let (entries, batch) = queue.next_batch(None, 2).await.unwrap();
+        assert_eq!(entries.len(), 2);
+        assert!(entries.contains_key(&0));
+        assert!(entries.contains_key(&1));
+        assert!(entries.get(&0).unwrap().batch_time.is_some());
+        assert!(entries.get(&1).unwrap().batch_time.is_some());
+        assert_eq!(batch.id, 0);
+        assert_eq!(batch.size, 2);
+
+        queue.append(default_entry());
+
+        assert!(queue.next_batch(Some(2), 2).await.is_none());
+    }
+
+    #[tokio::test]
+    async fn test_queue_next_batch_max_size() {
+        let queue = Queue::new();
+        queue.append(default_entry());
+        queue.append(default_entry());
+
+        let (entries, batch) = queue.next_batch(None, 1).await.unwrap();
+        assert_eq!(entries.len(), 1);
+        assert!(entries.contains_key(&0));
+        assert_eq!(batch.id, 0);
+        assert_eq!(batch.size, 1);
+
+        queue.append(default_entry());
+
+        let (entries, batch) = queue.next_batch(None, 3).await.unwrap();
+        assert_eq!(entries.len(), 2);
+        assert!(entries.contains_key(&1));
+        assert!(entries.contains_key(&2));
+        assert_eq!(batch.id, 1);
+        assert_eq!(batch.size, 2);
+    }
+}
diff --git a/router/src/validation.rs b/router/src/validation.rs
index ddb534d3..09220823 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -19,7 +19,7 @@ pub struct Validation {

 impl Validation {
     pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self {
-        // Crate channel
+        // Create channel
         let (validation_sender, validation_receiver) = mpsc::channel(128);

         // Launch background validation task
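
The core change in this patch swaps the lock-guarded Db for a message-passing actor: the queue state is owned by a single background task, callers send commands over an mpsc channel, and next_batch gets its reply back over a per-call oneshot channel. Below is a minimal, self-contained sketch of that same pattern outside the router, assuming only the tokio crate (with rt, macros, and sync features); the Handle, Command, push, and drain names are illustrative and not part of this patch.

use tokio::sync::{mpsc, oneshot};

// Commands accepted by the background task. Queries carry a oneshot
// sender so the task can reply directly to the caller.
#[derive(Debug)]
enum Command {
    Push(u64),
    Drain {
        max: usize,
        respond_to: oneshot::Sender<Vec<u64>>,
    },
}

// Cloneable handle; all clones talk to the same background task.
#[derive(Clone)]
struct Handle {
    sender: mpsc::UnboundedSender<Command>,
}

impl Handle {
    fn new() -> Self {
        let (sender, mut receiver) = mpsc::unbounded_channel();
        // The spawned task is the sole owner of the state, so no Mutex is
        // needed: commands are applied one at a time, in arrival order.
        tokio::spawn(async move {
            let mut items: Vec<u64> = Vec::new();
            while let Some(cmd) = receiver.recv().await {
                match cmd {
                    Command::Push(value) => items.push(value),
                    Command::Drain { max, respond_to } => {
                        let n = max.min(items.len());
                        let batch: Vec<u64> = items.drain(..n).collect();
                        // Ignore the error if the requester gave up waiting.
                        let _ = respond_to.send(batch);
                    }
                }
            }
        });
        Self { sender }
    }

    fn push(&self, value: u64) {
        // The background task holds the receiver for the life of the
        // runtime, so this send only fails after shutdown.
        self.sender.send(Command::Push(value)).unwrap();
    }

    async fn drain(&self, max: usize) -> Vec<u64> {
        let (respond_to, response) = oneshot::channel();
        self.sender.send(Command::Drain { max, respond_to }).unwrap();
        response.await.unwrap()
    }
}

#[tokio::main]
async fn main() {
    let handle = Handle::new();
    handle.push(1);
    handle.push(2);
    assert_eq!(handle.drain(8).await, vec![1, 2]);
}

Compared with the deleted Db, where every caller contended on a parking_lot::Mutex and next_batch_start_id had to track a cursor into the BTreeMap, the single-owner task serializes access by construction, and draining handed-out entries from the Vec removes the cursor bookkeeping entirely.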