From 45e060e857f618d786730a97149502c5e7c72107 Mon Sep 17 00:00:00 2001
From: Venkat Raman
Date: Thu, 26 Sep 2024 19:51:10 +0200
Subject: [PATCH] feat: propagate max_concurrent_requests to queue state
 entries instead of hardcoded 128 in backends/v3

---
 backends/v3/src/backend.rs |  2 ++
 backends/v3/src/lib.rs     |  2 ++
 backends/v3/src/main.rs    |  1 +
 backends/v3/src/queue.rs   | 31 ++++++++++++++++++-------------
 4 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs
index f8a10ca2..e5ee46cb 100644
--- a/backends/v3/src/backend.rs
+++ b/backends/v3/src/backend.rs
@@ -31,6 +31,7 @@ impl BackendV3 {
         max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
         max_batch_size: Option<usize>,
+        max_concurrent_requests: usize,
         requires_padding: bool,
         window_size: Option<u32>,
         speculate: u32,
@@ -46,6 +47,7 @@ impl BackendV3 {
         let block_size = attention.block_size();
 
         let queue = Queue::new(
+            max_concurrent_requests,
             requires_padding,
             block_size,
             prefix_caching,
diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs
index 77a9a11a..73589320 100644
--- a/backends/v3/src/lib.rs
+++ b/backends/v3/src/lib.rs
@@ -41,6 +41,7 @@ pub async fn connect_backend(
     max_batch_total_tokens: Option<u32>,
     max_waiting_tokens: usize,
     max_batch_size: Option<usize>,
+    max_concurrent_requests: usize,
 ) -> Result<(BackendV3, BackendInfo), V3Error> {
     // Helper function
     let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
@@ -118,6 +119,7 @@ pub async fn connect_backend(
         max_batch_total_tokens,
         max_waiting_tokens,
         max_batch_size,
+        max_concurrent_requests,
         shard_info.requires_padding,
         shard_info.window_size,
         shard_info.speculate,
diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs
index 471ddb5a..90d8bcba 100644
--- a/backends/v3/src/main.rs
+++ b/backends/v3/src/main.rs
@@ -167,6 +167,7 @@ async fn main() -> Result<(), RouterError> {
         max_batch_total_tokens,
         max_waiting_tokens,
         max_batch_size,
+        max_concurrent_requests,
     )
     .await?;
 
diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index f8123b57..e2f4acf0 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -44,6 +44,7 @@ pub(crate) struct Queue {
 
 impl Queue {
     pub(crate) fn new(
+        max_concurrent_requests: usize,
         requires_padding: bool,
         block_size: u32,
         prefix_caching: bool,
@@ -56,6 +57,7 @@ impl Queue {
 
         // Launch background queue task
        tokio::spawn(queue_task(
+            max_concurrent_requests,
             requires_padding,
             block_size,
             prefix_caching,
@@ -109,6 +111,7 @@ impl Queue {
 
 // Background task responsible of the queue state
 async fn queue_task(
+    max_concurrent_requests: usize,
     requires_padding: bool,
     block_size: u32,
     prefix_caching: bool,
@@ -118,6 +121,7 @@ async fn queue_task(
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
     let mut state = State::new(
+        max_concurrent_requests,
         requires_padding,
         block_size,
         prefix_caching,
@@ -178,6 +182,7 @@ struct State {
 
 impl State {
     fn new(
+        max_concurrent_requests: usize,
         requires_padding: bool,
         block_size: u32,
         prefix_caching: bool,
@@ -195,7 +200,7 @@ impl State {
         });
 
         Self {
-            entries: VecDeque::with_capacity(128),
+            entries: VecDeque::with_capacity(max_concurrent_requests),
             next_id: 0,
             next_batch_id: 0,
             block_size,
@@ -567,7 +572,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_append() {
-        let mut state = State::new(false, 1, false, None, 0, 16);
+        let mut state = State::new(128, false, 1, false, None, 0, 16);
         let (entry, _guard) = default_entry();
 
         assert_eq!(state.next_id, 0);
@@ -583,7 +588,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_next_batch_empty() {
-        let mut state = State::new(false, 1, false, None, 0, 16);
+        let mut state = State::new(128, false, 1, false, None, 0, 16);
 
         assert!(state.next_batch(None, None, 1, 1).await.is_none());
         assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -591,7 +596,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_next_batch_min_size() {
-        let mut state = State::new(false, 1, false, None, 0, 16);
+        let mut state = State::new(128, false, 1, false, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
@@ -623,7 +628,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_next_batch_max_size() {
-        let mut state = State::new(false, 1, false, None, 0, 16);
+        let mut state = State::new(128, false, 1, false, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
@@ -643,7 +648,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_next_batch_token_budget() {
-        let mut state = State::new(false, 1, false, None, 0, 2);
+        let mut state = State::new(128, false, 1, false, None, 0, 2);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
@@ -676,14 +681,14 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_append() {
-        let queue = Queue::new(false, 1, false, None, 0, 16);
+        let queue = Queue::new(128, false, 1, false, None, 0, 16);
         let (entry, _guard) = default_entry();
         queue.append(entry);
     }
 
     #[tokio::test]
     async fn test_queue_next_batch_empty() {
-        let queue = Queue::new(false, 1, false, None, 0, 16);
+        let queue = Queue::new(128, false, 1, false, None, 0, 16);
 
         assert!(queue.next_batch(None, None, 1, 1).await.is_none());
         assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -691,7 +696,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
-        let queue = Queue::new(false, 1, false, None, 0, 16);
+        let queue = Queue::new(128, false, 1, false, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -724,7 +729,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_max_size() {
-        let queue = Queue::new(false, 1, false, None, 0, 16);
+        let queue = Queue::new(128, false, 1, false, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -740,7 +745,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_token_budget() {
-        let queue = Queue::new(false, 1, false, None, 0, 16);
+        let queue = Queue::new(128, false, 1, false, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -765,7 +770,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_token_speculate() {
-        let queue = Queue::new(false, 1, false, None, 2, 16);
+        let queue = Queue::new(128, false, 1, false, None, 2, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -784,7 +789,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_dropped_receiver() {
-        let queue = Queue::new(false, 1, false, None, 0, 16);
+        let queue = Queue::new(128, false, 1, false, None, 0, 16);
         let (entry, _) = default_entry();
         queue.append(entry);