From c863f05cfd59f1ab6d68fc2c8e29f333286950a1 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 2 Feb 2023 12:54:56 +0100 Subject: [PATCH] feat(router): rework db to use a background task --- launcher/src/main.rs | 5 +- router/src/db.rs | 398 +++++++++++++++++++++++++++++---------- router/src/infer.rs | 7 +- router/src/validation.rs | 2 +- 4 files changed, 306 insertions(+), 106 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 20ec7faa..dea6fcc8 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -316,7 +316,10 @@ fn shard_manager( // If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard // Useful when running inside a HuggingFace Inference Endpoint if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") { - env.push(("WEIGHTS_CACHE_OVERRIDE".into(), weights_cache_override.into())); + env.push(( + "WEIGHTS_CACHE_OVERRIDE".into(), + weights_cache_override.into(), + )); }; // If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard diff --git a/router/src/db.rs b/router/src/db.rs index 246e4d5d..1c7f7399 100644 --- a/router/src/db.rs +++ b/router/src/db.rs @@ -3,12 +3,10 @@ use crate::infer::InferError; use crate::infer::InferStreamResponse; use crate::validation::ValidGenerateRequest; use nohash_hasher::{BuildNoHashHasher, IntMap}; -use parking_lot::Mutex; -use std::collections::BTreeMap; -use std::sync::Arc; +use std::cmp::min; use text_generation_client::{Batch, Request}; -use tokio::sync::mpsc::UnboundedSender; -use tokio::sync::OwnedSemaphorePermit; +use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; +use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit}; use tokio::time::Instant; /// Database entry @@ -29,132 +27,330 @@ pub(crate) struct Entry { /// Request Database #[derive(Debug, Clone)] pub(crate) struct Db { - pub shared: Arc, + /// Channel to communicate with the background database task + sender: UnboundedSender, } -/// Shared state -#[derive(Debug)] -pub struct Shared { - state: Mutex, +impl Db { + pub(crate) fn new() -> Self { + // Create channel + let (db_sender, db_receiver) = mpsc::unbounded_channel(); + + // Launch background database task + tokio::spawn(database_task(db_receiver)); + + Self { sender: db_sender } + } + + /// Append an entry to the database + pub(crate) fn append(&self, entry: Entry) { + // Send append command to the background task managing the state + // Unwrap is safe here + self.sender.send(DatabaseCommand::Append(entry)).unwrap(); + } + + // Get the next batch + pub(crate) async fn next_batch( + &self, + min_size: Option, + max_size: usize, + ) -> Option { + // Create response channel + let (sender, receiver) = oneshot::channel(); + // Send next batch command to the background task managing the state + // Unwrap is safe here + self.sender + .send(DatabaseCommand::NextBatch { + min_size, + max_size, + response_rx: sender, + }) + .unwrap(); + // Await on response channel + // Unwrap is safe here + receiver.await.unwrap() + } +} + +// Background task responsible of the database state +async fn database_task(mut receiver: UnboundedReceiver) { + let mut state = State::new(); + + while let Some(cmd) = receiver.recv().await { + match cmd { + DatabaseCommand::Append(entry) => state.append(entry), + DatabaseCommand::NextBatch { + min_size, + max_size, + response_rx, + } => { + let next_batch = state.next_batch(min_size, max_size); + response_rx.send(next_batch).unwrap_or(()); + } + } + } } /// Database State #[derive(Debug)] struct State { - /// Database entries organized in a BTreeMap to be able to iterate over them in order - entries: BTreeMap, + /// Database entries organized in a Vec + entries: Vec<(u64, Entry)>, /// Id of the next entry next_id: u64, /// Id of the next batch next_batch_id: u64, - - /// Start ID of the next batch. Used to iterate inside the entries BTreeMap - next_batch_start_id: u64, } impl State { - /// Get the next requests - fn next_requests(&self, max_size: usize) -> Option<(Vec, Vec)> { - // Iterates for max_size over the BTreemap starting from next_batch_start_id - let mut requests = Vec::new(); - let mut ids = Vec::new(); - - for (id, entry) in self - .entries - // Start from next_batch_start_id - .range(self.next_batch_start_id..) - // Take max_size - .take(max_size) - { - requests.push(Request { - id: *id, - inputs: entry.request.inputs.clone(), - input_length: entry.request.input_length, - parameters: Some(entry.request.parameters.clone()), - stopping_parameters: Some(entry.request.stopping_parameters.clone()), - }); - - ids.push(*id); + fn new() -> Self { + Self { + entries: Vec::with_capacity(128), + next_id: 0, + next_batch_id: 0, } - - if requests.is_empty() { - None - } else { - Some((ids, requests)) - } - } -} - -impl Db { - pub(crate) fn new() -> Self { - // Shared state - let shared = Arc::new(Shared { - state: Mutex::new(State { - entries: BTreeMap::new(), - next_id: 0, - next_batch_id: 0, - next_batch_start_id: 0, - }), - }); - - Self { shared } } /// Append an entry to the database - pub(crate) fn append(&self, entry: Entry) { - // Acquire lock - let mut state = self.shared.state.lock(); - - // Insert entry - let id = state.next_id; - state.next_id += 1; - state.entries.insert(id, entry); + fn append(&mut self, entry: Entry) { + self.entries.push((self.next_id, entry)); + self.next_id += 1; } // Get the next batch - pub(crate) fn next_batch( - &self, - min_size: Option, - max_size: usize, - ) -> Option<(IntMap, Batch)> { - // Acquire lock - let mut state = self.shared.state.lock(); - - // Get requests from the database - if let Some((ids, requests)) = state.next_requests(max_size) { - if let Some(min_size) = min_size { - // If min_size is set, only return a batch if there are enough requests - if requests.len() < min_size { - return None; - } + fn next_batch(&mut self, min_size: Option, max_size: usize) -> Option { + // Check if we have enough entries in DB by comparing next batch id and current id + if let Some(min_size) = min_size { + if self.entries.len() < min_size { + return None; } - // Batch size - let size = requests.len(); + } - let mut entries = IntMap::with_capacity_and_hasher(size, BuildNoHashHasher::default()); - ids.iter().for_each(|id| { - // Remove entry from db - let mut entry = state.entries.remove(id).unwrap(); + // If both ids are equal, the DB is empty + if self.entries.is_empty() { + return None; + } + + let next_batch_size = min(self.entries.len(), max_size); + + // Iterates for max_size over the BTreemap starting from next_batch_start_id + let mut batch_requests = Vec::new(); + let mut batch_entries = + IntMap::with_capacity_and_hasher(next_batch_size, BuildNoHashHasher::default()); + + self.entries + .drain(..next_batch_size) + .for_each(|(id, mut entry)| { + batch_requests.push(Request { + id, + inputs: entry.request.inputs.clone(), + input_length: entry.request.input_length, + parameters: Some(entry.request.parameters.clone()), + stopping_parameters: Some(entry.request.stopping_parameters.clone()), + }); // Set batch_time entry.batch_time = Some(Instant::now()); // Insert in entries IntMap - entries.insert(*id, entry); + batch_entries.insert(id, entry); }); - let batch = Batch { - id: state.next_batch_id, - requests, - size: size as u32, - }; - // Update next_batch_start_id to the last id in the batch + 1 - state.next_batch_start_id = ids.last().unwrap() + 1; - // Increment batch id - state.next_batch_id += 1; + let batch = Batch { + id: self.next_batch_id, + requests: batch_requests, + size: next_batch_size as u32, + }; + // Increment batch id + self.next_batch_id += 1; - return Some((entries, batch)); - } - None + Some((batch_entries, batch)) + } +} + +type NextBatch = (IntMap, Batch); + +#[derive(Debug)] +enum DatabaseCommand { + Append(Entry), + NextBatch { + min_size: Option, + max_size: usize, + response_rx: oneshot::Sender>, + }, +} + +#[cfg(test)] +mod tests { + use super::*; + use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; + use tokio::sync::{mpsc, Semaphore}; + + fn default_entry() -> Entry { + let semaphore = Arc::new(Semaphore::new(1)); + let (response_tx, _) = mpsc::unbounded_channel(); + let permit = semaphore.try_acquire_owned().unwrap(); + + Entry { + request: ValidGenerateRequest { + inputs: "".to_string(), + input_length: 0, + parameters: NextTokenChooserParameters { + temperature: 0.0, + top_k: 0, + top_p: 0.0, + do_sample: false, + seed: 0, + repetition_penalty: 0.0, + }, + stopping_parameters: StoppingCriteriaParameters { + max_new_tokens: 0, + stop_sequences: vec![], + }, + }, + response_tx, + time: Instant::now(), + batch_time: None, + _permit: permit, + } + } + + #[test] + fn test_append() { + let mut state = State::new(); + let entry = default_entry(); + + assert_eq!(state.next_id, 0); + assert_eq!(state.entries.len(), 0); + + state.append(entry); + + assert_eq!(state.next_id, 1); + assert_eq!(state.entries.len(), 1); + let (id, _) = state.entries.remove(0); + assert_eq!(id, 0); + } + + #[test] + fn test_next_batch_empty() { + let mut state = State::new(); + + assert!(state.next_batch(None, 1).is_none()); + assert!(state.next_batch(Some(1), 1).is_none()); + } + + #[test] + fn test_next_batch_min_size() { + let mut state = State::new(); + state.append(default_entry()); + state.append(default_entry()); + + let (entries, batch) = state.next_batch(None, 2).unwrap(); + assert_eq!(entries.len(), 2); + assert!(entries.contains_key(&0)); + assert!(entries.contains_key(&1)); + assert!(entries.get(&0).unwrap().batch_time.is_some()); + assert!(entries.get(&1).unwrap().batch_time.is_some()); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 2); + + assert_eq!(state.next_id, 2); + assert_eq!(state.entries.len(), 0); + assert_eq!(state.next_batch_id, 1); + + state.append(default_entry()); + + assert!(state.next_batch(Some(2), 2).is_none()); + + assert_eq!(state.next_id, 3); + assert_eq!(state.entries.len(), 1); + let (id, _) = state.entries.remove(0); + assert_eq!(id, 2); + } + + #[test] + fn test_next_batch_max_size() { + let mut state = State::new(); + state.append(default_entry()); + state.append(default_entry()); + + let (entries, batch) = state.next_batch(None, 1).unwrap(); + assert_eq!(entries.len(), 1); + assert!(entries.contains_key(&0)); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 1); + + assert_eq!(state.next_id, 2); + assert_eq!(state.entries.len(), 1); + assert_eq!(state.next_batch_id, 1); + + state.append(default_entry()); + + let (entries, batch) = state.next_batch(None, 3).unwrap(); + assert_eq!(entries.len(), 2); + assert!(entries.contains_key(&1)); + assert!(entries.contains_key(&2)); + assert_eq!(batch.id, 1); + assert_eq!(batch.size, 2); + + assert_eq!(state.next_id, 3); + assert_eq!(state.entries.len(), 0); + assert_eq!(state.next_batch_id, 2); + } + + #[tokio::test] + async fn test_db_append() { + let db = Db::new(); + db.append(default_entry()); + } + + #[tokio::test] + async fn test_db_next_batch_empty() { + let db = Db::new(); + + assert!(db.next_batch(None, 1).await.is_none()); + assert!(db.next_batch(Some(1), 1).await.is_none()); + } + + #[tokio::test] + async fn test_db_next_batch_min_size() { + let db = Db::new(); + db.append(default_entry()); + db.append(default_entry()); + + let (entries, batch) = db.next_batch(None, 2).await.unwrap(); + assert_eq!(entries.len(), 2); + assert!(entries.contains_key(&0)); + assert!(entries.contains_key(&1)); + assert!(entries.get(&0).unwrap().batch_time.is_some()); + assert!(entries.get(&1).unwrap().batch_time.is_some()); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 2); + + db.append(default_entry()); + + assert!(db.next_batch(Some(2), 2).await.is_none()); + } + + #[tokio::test] + async fn test_db_next_batch_max_size() { + let db = Db::new(); + db.append(default_entry()); + db.append(default_entry()); + + let (entries, batch) = db.next_batch(None, 1).await.unwrap(); + assert_eq!(entries.len(), 1); + assert!(entries.contains_key(&0)); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 1); + + db.append(default_entry()); + + let (entries, batch) = db.next_batch(None, 3).await.unwrap(); + assert_eq!(entries.len(), 2); + assert!(entries.contains_key(&1)); + assert!(entries.contains_key(&2)); + assert_eq!(batch.id, 1); + assert_eq!(batch.size, 2); } } diff --git a/router/src/infer.rs b/router/src/infer.rs index 23e84265..ac4b6a92 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -188,7 +188,7 @@ async fn batching_task( // Get the next batch from the DB // This batch might be smaller than the maximum batch size if there are not enough requests // waiting in the DB - while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) { + while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size).await { let mut cached_batch = wrap_future(client.prefill(batch), &mut entries).await; let mut waiting_tokens = 1; @@ -210,8 +210,9 @@ async fn batching_task( }; // Try to get a new batch - if let Some((mut new_entries, new_batch)) = - db.next_batch(min_size, max_batch_size - batch_size as usize) + if let Some((mut new_entries, new_batch)) = db + .next_batch(min_size, max_batch_size - batch_size as usize) + .await { // Generate one token for this new batch to have the attention past in cache let new_cached_batch = diff --git a/router/src/validation.rs b/router/src/validation.rs index ddb534d3..09220823 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -19,7 +19,7 @@ pub struct Validation { impl Validation { pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self { - // Crate channel + // Create channel let (validation_sender, validation_receiver) = mpsc::channel(128); // Launch background validation task