diff --git a/router/src/db.rs b/router/src/db.rs index 2f2fcc3f..b480408f 100644 --- a/router/src/db.rs +++ b/router/src/db.rs @@ -4,6 +4,7 @@ use crate::infer::InferStreamResponse; use crate::validation::ValidGenerateRequest; use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::cmp::min; +use async_stream::reexport::next; use text_generation_client::{Batch, Request}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit}; @@ -121,25 +122,24 @@ impl State { // Get the next batch fn next_batch(&mut self, min_size: Option, max_size: usize) -> Option { - // Check if we have enough entries in DB by comparing next batch id and current id + if self.entries.is_empty() { + return None; + } + + // Check if we have enough entries in DB if let Some(min_size) = min_size { if self.entries.len() < min_size { return None; } } - // If both ids are equal, the DB is empty - if self.entries.is_empty() { - return None; - } - let next_batch_size = min(self.entries.len(), max_size); - // Iterates for max_size over the BTreemap starting from next_batch_start_id - let mut batch_requests = Vec::new(); + let mut batch_requests = Vec::with_capacity(next_batch_size); let mut batch_entries = IntMap::with_capacity_and_hasher(next_batch_size, BuildNoHashHasher::default()); + // Drain next_batch_size entries self.entries .drain(..next_batch_size) .for_each(|(id, mut entry)| { @@ -152,7 +152,7 @@ impl State { }); // Set batch_time entry.batch_time = Some(Instant::now()); - // Insert in entries IntMap + // Insert in batch_entries IntMap batch_entries.insert(id, entry); });