diff --git a/router/src/infer.rs b/router/src/infer.rs
index ac4b6a92..400be55e 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -1,7 +1,7 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
 use crate::GenerateRequest;
-use crate::{Db, Entry, Token};
+use crate::{Queue, Entry, Token};
 use nohash_hasher::IntMap;
 use std::future::Future;
 use std::sync::Arc;
@@ -20,8 +20,8 @@ use tracing::instrument;
 pub struct Infer {
     /// Validation
     validation: Validation,
-    /// Request database
-    db: Db,
+    /// Request queue
+    queue: Queue,
     /// Shared state
     shared: Arc<Shared>,
     /// Inference limit
@@ -43,7 +43,7 @@ impl Infer {
         max_concurrent_requests: usize,
     ) -> Self {
         // Infer shared state
-        let db = Db::new();
+        let queue = Queue::new();
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
@@ -53,7 +53,7 @@ impl Infer {
             client,
             max_batch_size,
             max_waiting_tokens,
-            db.clone(),
+            queue.clone(),
             shared.clone(),
         ));
 
@@ -62,13 +62,13 @@ impl Infer {
 
         Self {
             validation,
-            db,
+            queue,
             shared,
             limit_concurrent_requests: semaphore,
         }
     }
 
-    /// Add a new request to the database and return a stream of InferStreamResponse
+    /// Add a new request to the queue and return a stream of InferStreamResponse
     pub(crate) async fn generate_stream(
         &self,
         request: GenerateRequest,
@@ -83,8 +83,8 @@ impl Infer {
         // MPSC channel to communicate with the background batching task
         let (response_tx, response_rx) = mpsc::unbounded_channel();
 
-        // Append the request to the database
-        self.db.append(Entry {
+        // Append the request to the queue
+        self.queue.append(Entry {
             request: valid_request,
             response_tx,
             time: Instant::now(),
@@ -92,7 +92,7 @@ impl Infer {
             _permit: permit,
         });
 
-        // Notify the background task that we have a new entry in the database that needs
+        // Notify the background task that we have a new entry in the queue that needs
         // to be batched
         self.shared.batching_task.notify_one();
 
@@ -100,7 +100,7 @@ impl Infer {
         Ok(UnboundedReceiverStream::new(response_rx))
     }
 
-    /// Add a new request to the database and return a InferResponse
+    /// Add a new request to the queue and return an InferResponse
     pub(crate) async fn generate(
         &self,
         request: GenerateRequest,
@@ -169,12 +169,12 @@ impl Infer {
 /// Will be launched in a background Tokio task
 ///
 /// Batches requests and sends them to the inference server
-#[instrument(skip(client, db, shared))]
+#[instrument(skip(client, queue, shared))]
 async fn batching_task(
     mut client: ShardedClient,
     max_batch_size: usize,
     max_waiting_tokens: usize,
-    db: Db,
+    queue: Queue,
     shared: Arc<Shared>,
 ) {
     // Minimum batch size after which we try to add more requests
@@ -185,10 +185,10 @@ async fn batching_task(
         // Wait for a notification from the Infer struct
         shared.batching_task.notified().await;
 
-        // Get the next batch from the DB
+        // Get the next batch from the queue
         // This batch might be smaller than the maximum batch size if there are not enough requests
-        // waiting in the DB
-        while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size).await {
+        // waiting in the queue
+        while let Some((mut entries, batch)) = queue.next_batch(None, max_batch_size).await {
             let mut cached_batch = wrap_future(client.prefill(batch), &mut entries).await;
             let mut waiting_tokens = 1;
 
@@ -210,7 +210,7 @@ async fn batching_task(
             };
 
             // Try to get a new batch
-            if let Some((mut new_entries, new_batch)) = db
+            if let Some((mut new_entries, new_batch)) = queue
                 .next_batch(min_size, max_batch_size - batch_size as usize)
                 .await
             {
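Note: the only coupling between generate_stream and batching_task above is the tokio::sync::Notify handshake: the handler appends to the queue and calls notify_one(); the batching loop parks on notified().await until work arrives. A minimal sketch of that handshake outside the router (the notify/worker names are illustrative, not router identifiers):

use std::sync::Arc;
use tokio::sync::Notify;

#[tokio::main]
async fn main() {
    let notify = Arc::new(Notify::new());
    let worker_notify = notify.clone();

    // Stands in for batching_task: park until a request handler signals new work.
    let worker = tokio::spawn(async move {
        worker_notify.notified().await;
        println!("woken: pull the next batch from the queue here");
    });

    // Stands in for generate_stream: append to the queue, then wake the worker.
    notify.notify_one();
    worker.await.unwrap();
}

Because Notify stores a permit when notify_one() finds no waiter, a notification sent while the batching loop is still busy with the current batch is not lost; the next notified().await returns immediately.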
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 5b96485f..e28634e5 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1,10 +1,10 @@
 /// Text Generation Inference Webserver
-mod db;
+mod queue;
 mod infer;
 pub mod server;
 mod validation;
 
-use db::{Db, Entry};
+use queue::{Queue, Entry};
 use infer::Infer;
 use serde::{Deserialize, Serialize};
 use validation::Validation;
diff --git a/router/src/db.rs b/router/src/queue.rs
similarity index 81%
rename from router/src/db.rs
rename to router/src/queue.rs
index 2606abd7..2aaf93b1 100644
--- a/router/src/db.rs
+++ b/router/src/queue.rs
@@ -1,4 +1,3 @@
-/// This code is massively inspired by Tokio mini-redis
 use crate::infer::InferError;
 use crate::infer::InferStreamResponse;
 use crate::validation::ValidGenerateRequest;
@@ -9,7 +8,7 @@ use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
 use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
 use tokio::time::Instant;
 
-/// Database entry
+/// Queue entry
 #[derive(Debug)]
 pub(crate) struct Entry {
     /// Request
@@ -24,29 +23,29 @@ pub(crate) struct Entry {
     pub _permit: OwnedSemaphorePermit,
 }
 
-/// Request Database
+/// Request Queue
 #[derive(Debug, Clone)]
-pub(crate) struct Db {
-    /// Channel to communicate with the background database task
-    db_sender: UnboundedSender<DatabaseCommand>,
+pub(crate) struct Queue {
+    /// Channel to communicate with the background queue task
+    queue_sender: UnboundedSender<QueueCommand>,
 }
 
-impl Db {
+impl Queue {
     pub(crate) fn new() -> Self {
         // Create channel
-        let (db_sender, db_receiver) = mpsc::unbounded_channel();
+        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
 
-        // Launch background database task
-        tokio::spawn(database_task(db_receiver));
+        // Launch background queue task
+        tokio::spawn(queue_task(queue_receiver));
 
-        Self { db_sender }
+        Self { queue_sender }
     }
 
-    /// Append an entry to the database
+    /// Append an entry to the queue
     pub(crate) fn append(&self, entry: Entry) {
         // Send append command to the background task managing the state
         // Unwrap is safe here
-        self.db_sender.send(DatabaseCommand::Append(entry)).unwrap();
+        self.queue_sender.send(QueueCommand::Append(entry)).unwrap();
     }
 
     // Get the next batch
@@ -59,8 +58,8 @@ impl Db {
         let (response_sender, response_receiver) = oneshot::channel();
         // Send next batch command to the background task managing the state
         // Unwrap is safe here
-        self.db_sender
-            .send(DatabaseCommand::NextBatch {
+        self.queue_sender
+            .send(QueueCommand::NextBatch {
                 min_size,
                 max_size,
                 response_sender,
@@ -72,14 +71,14 @@
     }
 }
 
-// Background task responsible of the database state
-async fn database_task(mut receiver: UnboundedReceiver<DatabaseCommand>) {
+// Background task responsible for the queue state
+async fn queue_task(mut receiver: UnboundedReceiver<QueueCommand>) {
     let mut state = State::new();
 
     while let Some(cmd) = receiver.recv().await {
         match cmd {
-            DatabaseCommand::Append(entry) => state.append(entry),
-            DatabaseCommand::NextBatch {
+            QueueCommand::Append(entry) => state.append(entry),
+            QueueCommand::NextBatch {
                 min_size,
                 max_size,
                 response_sender,
@@ -91,10 +90,10 @@
     }
 }
 
-/// Database State
+/// Queue State
 #[derive(Debug)]
 struct State {
-    /// Database entries organized in a Vec
+    /// Queue entries organized in a Vec
     entries: Vec<(u64, Entry)>,
 
     /// Id of the next entry
@@ -113,7 +112,7 @@ impl State {
         }
     }
 
-    /// Append an entry to the database
+    /// Append an entry to the queue
    fn append(&mut self, entry: Entry) {
         self.entries.push((self.next_id, entry));
         self.next_id += 1;
@@ -125,7 +124,7 @@ impl State {
             return None;
         }
-        // Check if we have enough entries in DB
+        // Check if we have enough entries
         if let Some(min_size) = min_size {
             if self.entries.len() < min_size {
                 return None;
             }
@@ -170,7 +169,7 @@ impl State {
 type NextBatch = (IntMap<u64, Entry>, Batch);
 
 #[derive(Debug)]
-enum DatabaseCommand {
+enum QueueCommand {
     Append(Entry),
     NextBatch {
         min_size: Option<usize>,
@@ -299,26 +298,26 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn test_db_append() {
-        let db = Db::new();
-        db.append(default_entry());
+    async fn test_queue_append() {
+        let queue = Queue::new();
+        queue.append(default_entry());
     }
 
     #[tokio::test]
-    async fn test_db_next_batch_empty() {
-        let db = Db::new();
+    async fn test_queue_next_batch_empty() {
+        let queue = Queue::new();
 
-        assert!(db.next_batch(None, 1).await.is_none());
-        assert!(db.next_batch(Some(1), 1).await.is_none());
+        assert!(queue.next_batch(None, 1).await.is_none());
+        assert!(queue.next_batch(Some(1), 1).await.is_none());
     }
 
     #[tokio::test]
-    async fn test_db_next_batch_min_size() {
-        let db = Db::new();
-        db.append(default_entry());
-        db.append(default_entry());
+    async fn test_queue_next_batch_min_size() {
+        let queue = Queue::new();
+        queue.append(default_entry());
+        queue.append(default_entry());
 
-        let (entries, batch) = db.next_batch(None, 2).await.unwrap();
+        let (entries, batch) = queue.next_batch(None, 2).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -327,26 +326,26 @@
         assert_eq!(batch.id, 0);
         assert_eq!(batch.size, 2);
 
-        db.append(default_entry());
+        queue.append(default_entry());
 
-        assert!(db.next_batch(Some(2), 2).await.is_none());
+        assert!(queue.next_batch(Some(2), 2).await.is_none());
     }
 
     #[tokio::test]
-    async fn test_db_next_batch_max_size() {
-        let db = Db::new();
-        db.append(default_entry());
-        db.append(default_entry());
+    async fn test_queue_next_batch_max_size() {
+        let queue = Queue::new();
+        queue.append(default_entry());
+        queue.append(default_entry());
 
-        let (entries, batch) = db.next_batch(None, 1).await.unwrap();
+        let (entries, batch) = queue.next_batch(None, 1).await.unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert_eq!(batch.id, 0);
         assert_eq!(batch.size, 1);
 
-        db.append(default_entry());
+        queue.append(default_entry());
 
-        let (entries, batch) = db.next_batch(None, 3).await.unwrap();
+        let (entries, batch) = queue.next_batch(None, 3).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&1));
         assert!(entries.contains_key(&2));
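Note: although the mini-redis attribution comment is dropped, Queue keeps the same actor-style shape: a cheaply clonable handle pushes QueueCommands down an mpsc channel to one background task that solely owns the State, and NextBatch carries a oneshot sender for the reply. A condensed, self-contained sketch of that pattern under assumed names (SketchQueue, Command, with entries reduced to strings), not the router's actual API:

use tokio::sync::{mpsc, oneshot};

#[derive(Debug)]
enum Command {
    Append(String),
    // The oneshot sender lets the state-owning task answer exactly one caller.
    NextBatch {
        max_size: usize,
        response_sender: oneshot::Sender<Vec<String>>,
    },
}

#[derive(Debug, Clone)]
struct SketchQueue {
    sender: mpsc::UnboundedSender<Command>,
}

impl SketchQueue {
    fn new() -> Self {
        let (sender, mut receiver) = mpsc::unbounded_channel();
        // The spawned task is the sole owner of `entries`, so no Mutex is needed:
        // appends and batch requests are serialized by the channel itself.
        tokio::spawn(async move {
            let mut entries: Vec<String> = Vec::new();
            while let Some(cmd) = receiver.recv().await {
                match cmd {
                    Command::Append(entry) => entries.push(entry),
                    Command::NextBatch { max_size, response_sender } => {
                        let batch: Vec<String> =
                            entries.drain(..max_size.min(entries.len())).collect();
                        // Ignore the error: the caller may have dropped its receiver.
                        let _ = response_sender.send(batch);
                    }
                }
            }
        });
        Self { sender }
    }

    fn append(&self, entry: String) {
        self.sender.send(Command::Append(entry)).unwrap();
    }

    async fn next_batch(&self, max_size: usize) -> Vec<String> {
        let (response_sender, response_receiver) = oneshot::channel();
        self.sender
            .send(Command::NextBatch { max_size, response_sender })
            .unwrap();
        response_receiver.await.unwrap()
    }
}

#[tokio::main]
async fn main() {
    let queue = SketchQueue::new();
    queue.append("req-1".into());
    queue.append("req-2".into());
    assert_eq!(queue.next_batch(8).await.len(), 2);
}

The .unwrap() on send mirrors the "Unwrap is safe here" comments above: the background task only exits once every sender is dropped, so a send from a live handle does not fail.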