From 2af011a1c049fe5ddd5256c6d39a0270dbc4081e Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 8 Feb 2024 17:26:55 +0100
Subject: [PATCH] use max_size in the batch task

---
 router/src/infer.rs | 14 +++++++++++---
 router/src/queue.rs | 45 +++++++++++++++++++++++----------------------
 2 files changed, 34 insertions(+), 25 deletions(-)

diff --git a/router/src/infer.rs b/router/src/infer.rs
index f4441604..d6dbd842 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -70,7 +70,7 @@ impl Infer {
         tokenizer_config: HubTokenizerConfig,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new(requires_padding, max_batch_size, 16, window_size, speculate);
+        let queue = Queue::new(requires_padding, 16, window_size, speculate);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
@@ -82,6 +82,7 @@ impl Infer {
             max_batch_prefill_tokens,
             max_batch_total_tokens,
             max_waiting_tokens,
+            max_batch_size,
             queue.clone(),
             shared.clone(),
             generation_health,
@@ -339,6 +340,7 @@ async fn batching_task(
     max_batch_prefill_tokens: u32,
     max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
+    max_batch_size: Option<usize>,
     queue: Queue,
     shared: Arc<Shared>,
     generation_health: Arc<AtomicBool>,
@@ -352,7 +354,12 @@ async fn batching_task(
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the queue
         while let Some((mut entries, batch, span)) = queue
-            .next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
+            .next_batch(
+                None,
+                max_batch_size,
+                max_batch_prefill_tokens,
+                max_batch_total_tokens,
+            )
             .await
         {
             let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
@@ -380,10 +387,11 @@ async fn batching_task(
                 };
 
                 let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
+                let max_size = max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
 
                 // Try to get a new batch
                 if let Some((mut new_entries, new_batch, span)) = queue
-                    .next_batch(min_size, max_batch_prefill_tokens, token_budget)
+                    .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
                     .await
                 {
                     // Tracking metrics
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 8d855049..b9db493d 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -36,7 +36,6 @@ pub(crate) struct Queue {
 impl Queue {
     pub(crate) fn new(
         requires_padding: bool,
-        max_batch_size: Option<usize>,
         block_size: u32,
         window_size: Option<u32>,
         speculate: u32,
@@ -47,7 +46,6 @@ impl Queue {
         // Launch background queue task
         tokio::spawn(queue_task(
             requires_padding,
-            max_batch_size,
             block_size,
             window_size,
             speculate,
@@ -72,6 +70,7 @@ impl Queue {
     pub(crate) async fn next_batch(
         &self,
         min_size: Option<usize>,
+        max_size: Option<usize>,
         prefill_token_budget: u32,
         token_budget: u32,
     ) -> Option<NextBatch> {
@@ -82,6 +81,7 @@ impl Queue {
         self.queue_sender
             .send(QueueCommand::NextBatch {
                 min_size,
+                max_size,
                 prefill_token_budget,
                 token_budget,
                 response_sender,
@@ -97,7 +97,6 @@ impl Queue {
 // Background task responsible of the queue state
 async fn queue_task(
     requires_padding: bool,
-    max_size: Option<usize>,
     block_size: u32,
     window_size: Option<u32>,
     speculate: u32,
@@ -113,6 +112,7 @@ async fn queue_task(
             }
             QueueCommand::NextBatch {
                 min_size,
+                max_size,
                 prefill_token_budget,
                 token_budget,
                 response_sender,
@@ -332,6 +332,7 @@ enum QueueCommand {
     Append(Box<Entry>, Span),
     NextBatch {
         min_size: Option<usize>,
+        max_size: Option<usize>,
         prefill_token_budget: u32,
         token_budget: u32,
         response_sender: oneshot::Sender<Option<NextBatch>>,
@@ -494,28 +495,28 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_append() {
-        let queue = Queue::new(false, None, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0);
         let (entry, _guard) = default_entry();
         queue.append(entry);
     }
 
     #[tokio::test]
     async fn test_queue_next_batch_empty() {
-        let queue = Queue::new(false, None, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0);
 
-        assert!(queue.next_batch(None, 1, 1).await.is_none());
-        assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
+        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
+        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
     }
 
     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
-        let queue = Queue::new(false, None, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
         queue.append(entry2);
 
-        let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -528,11 +529,11 @@ mod tests {
         queue.append(entry3);
 
         // Not enough requests pending
-        assert!(queue.next_batch(Some(2), 2, 2).await.is_none());
+        assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
         // Not enough token budget
-        assert!(queue.next_batch(Some(1), 0, 0).await.is_none());
+        assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
         // Ok
-        let (entries2, batch2, _) = queue.next_batch(Some(1), 2, 2).await.unwrap();
+        let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
         assert_eq!(entries2.len(), 1);
         assert!(entries2.contains_key(&2));
         assert!(entries2.get(&2).unwrap().batch_time.is_some());
@@ -542,13 +543,13 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_max_size() {
-        let queue = Queue::new(false, Some(1), 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
         queue.append(entry2);
 
-        let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert!(entries.get(&0).unwrap().batch_time.is_some());
@@ -558,13 +559,13 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_token_budget() {
-        let queue = Queue::new(false, None, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
         queue.append(entry2);
 
-        let (entries, batch, _) = queue.next_batch(None, 1, 1).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert_eq!(batch.id, 0);
@@ -573,7 +574,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         queue.append(entry3);
 
-        let (entries, batch, _) = queue.next_batch(None, 3, 3).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&1));
         assert!(entries.contains_key(&2));
@@ -583,16 +584,16 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_token_speculate() {
-        let queue = Queue::new(false, None, 1, None, 2);
+        let queue = Queue::new(false, 1, None, 2);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
         queue.append(entry2);
 
         // Budget of 1 is not enough
-        assert!(queue.next_batch(None, 1, 1).await.is_none());
+        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
 
-        let (entries, batch, _) = queue.next_batch(None, 6, 6).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -602,10 +603,10 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_dropped_receiver() {
-        let queue = Queue::new(false, None, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0);
         let (entry, _) = default_entry();
         queue.append(entry);
 
-        assert!(queue.next_batch(None, 1, 1).await.is_none());
+        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
     }
 }