diff --git a/router/src/infer.rs b/router/src/infer.rs index d0d22d3b..188ddc64 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -53,7 +53,7 @@ impl Infer { generation_health: Arc, ) -> Self { // Infer shared state - let queue = Queue::new(requires_padding); + let queue = Queue::new(requires_padding, 16); let shared = Arc::new(Shared { batching_task: Notify::new(), }); diff --git a/router/src/queue.rs b/router/src/queue.rs index 60831c3b..2d8d6d1c 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -33,12 +33,12 @@ pub(crate) struct Queue { } impl Queue { - pub(crate) fn new(requires_padding: bool) -> Self { + pub(crate) fn new(requires_padding: bool, block_size: u32) -> Self { // Create channel let (queue_sender, queue_receiver) = flume::unbounded(); // Launch background queue task - tokio::spawn(queue_task(requires_padding, queue_receiver)); + tokio::spawn(queue_task(requires_padding, block_size, queue_receiver)); Self { queue_sender } } @@ -81,8 +81,12 @@ impl Queue { } // Background task responsible of the queue state -async fn queue_task(requires_padding: bool, receiver: flume::Receiver) { - let mut state = State::new(requires_padding); +async fn queue_task( + requires_padding: bool, + block_size: u32, + receiver: flume::Receiver, +) { + let mut state = State::new(requires_padding, block_size); while let Ok(cmd) = receiver.recv_async().await { match cmd { @@ -119,15 +123,19 @@ struct State { /// Whether the model is using padding requires_padding: bool, + + /// Paged Attention block size + block_size: u32, } impl State { - fn new(requires_padding: bool) -> Self { + fn new(requires_padding: bool, block_size: u32) -> Self { Self { entries: VecDeque::with_capacity(128), next_id: 0, next_batch_id: 0, requires_padding, + block_size, } } @@ -188,7 +196,9 @@ impl State { prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length } else { // pad to block size - prefill_tokens += ((entry.request.input_length + 16 - 1) / 16) * 16; + prefill_tokens += ((entry.request.input_length + self.block_size - 1) + / self.block_size) + * self.block_size; } if self.requires_padding { @@ -196,7 +206,9 @@ impl State { } else { // pad to block size decode_tokens += - ((entry.request.stopping_parameters.max_new_tokens + 16 - 1) / 16) * 16; + ((entry.request.stopping_parameters.max_new_tokens + self.block_size - 1) + / self.block_size) + * self.block_size; } if prefill_tokens > prefill_token_budget @@ -328,7 +340,7 @@ mod tests { #[test] fn test_append() { - let mut state = State::new(false); + let mut state = State::new(false, 1); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -344,7 +356,7 @@ mod tests { #[test] fn test_next_batch_empty() { - let mut state = State::new(false); + let mut state = State::new(false, 1); assert!(state.next_batch(None, 1, 1).is_none()); assert!(state.next_batch(Some(1), 1, 1).is_none()); @@ -352,7 +364,7 @@ mod tests { #[test] fn test_next_batch_min_size() { - let mut state = State::new(false); + let mut state = State::new(false, 1); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -384,7 +396,7 @@ mod tests { #[test] fn test_next_batch_token_budget() { - let mut state = State::new(false); + let mut state = State::new(false, 1); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -417,14 +429,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); assert!(queue.next_batch(None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), 1, 1).await.is_none()); @@ -432,7 +444,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -465,7 +477,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -490,7 +502,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); let (entry, _) = default_entry(); queue.append(entry);