diff --git a/router/src/infer.rs b/router/src/infer.rs
index aa6dc664..79c5eb17 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -49,11 +49,12 @@ impl Infer {
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
         requires_padding: bool,
+        max_input_length: u32,
         window_size: Option<u32>,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new(requires_padding, 16, window_size);
+        let queue = Queue::new(requires_padding, max_input_length, 16, window_size);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
diff --git a/router/src/queue.rs b/router/src/queue.rs
index bbb8db0e..4e6e99a8 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -34,13 +34,19 @@ pub(crate) struct Queue {
 }
 
 impl Queue {
-    pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
+    pub(crate) fn new(
+        requires_padding: bool,
+        max_input_length: u32,
+        block_size: u32,
+        window_size: Option<u32>
+    ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
 
         // Launch background queue task
         tokio::spawn(queue_task(
             requires_padding,
+            max_input_length,
             block_size,
             window_size,
             queue_receiver,
@@ -89,11 +95,12 @@
 // Background task responsible of the queue state
 async fn queue_task(
     requires_padding: bool,
+    max_input_length: u32,
     block_size: u32,
     window_size: Option<u32>,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
-    let mut state = State::new(requires_padding, block_size, window_size);
+    let mut state = State::new(requires_padding, max_input_length, block_size, window_size);
 
     while let Some(cmd) = receiver.recv().await {
         match cmd {
@@ -131,6 +138,9 @@ struct State {
     /// Whether the model is using padding
     requires_padding: bool,
 
+    /// Maximum input length, required for padding scenario
+    max_input_length: u32,
+
     /// Paged Attention block size
     block_size: u32,
 
@@ -139,12 +149,18 @@
 }
 impl State {
-    fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
+    fn new(
+        requires_padding: bool,
+        max_input_length: u32,
+        block_size: u32,
+        window_size: Option<u32>
+    ) -> Self {
         Self {
             entries: VecDeque::with_capacity(128),
             next_id: 0,
             next_batch_id: 0,
             requires_padding,
+            max_input_length,
             block_size,
             window_size,
         }
     }
@@ -187,7 +203,6 @@ impl State {
         let mut batch_entries =
            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
 
-        let mut max_input_length = 0;
 
         let mut prefill_tokens: u32 = 0;
         let mut decode_tokens: u32 = 0;
@@ -203,8 +218,7 @@
             if self.requires_padding {
                 // We pad to max input length in the Python shards
                 // We need to take these padding tokens into the equation
-                max_input_length = max_input_length.max(entry.request.input_length);
-                prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
+                prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_length
             } else {
                 // pad to block size
                 prefill_tokens += ((entry.request.input_length + self.block_size - 1)
diff --git a/router/src/server.rs b/router/src/server.rs
index f254afd8..97bc20c2 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -595,6 +595,7 @@ pub async fn run(
         max_waiting_tokens,
         max_concurrent_requests,
         shard_info.requires_padding,
+        max_input_length as u32,
         shard_info.window_size,
         generation_health,
     );
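Note on the accounting change in `next_batch`: for padding models, the queue previously budgeted prefill tokens against the longest input seen so far in the candidate batch, while the shards may pad sequences up to the router-wide maximum; the patch instead charges every request the configured `max_input_length`. Below is a minimal sketch of the before/after estimate; the helper names and the numbers in `main` are hypothetical and not part of the patch.

```rust
/// Old estimate: scale by the longest input seen so far in this batch.
fn prefill_tokens_old(batch_len: u32, batch_max_input_length: u32) -> u32 {
    (batch_len + 1) * batch_max_input_length
}

/// New estimate: charge every request the router-wide `max_input_length`,
/// the worst case a padding model can reach.
fn prefill_tokens_new(batch_len: u32, max_input_length: u32) -> u32 {
    (batch_len + 1) * max_input_length
}

fn main() {
    // Hypothetical numbers: 3 requests already batched, longest input seen
    // so far is 512 tokens, router-configured max_input_length is 1024.
    assert_eq!(prefill_tokens_old(3, 512), 2_048); // optimistic estimate
    assert_eq!(prefill_tokens_new(3, 1024), 4_096); // pessimistic, safe
}
```

The new estimate is deliberately conservative: on paper it admits fewer requests per batch, but padding in the shards can no longer push actual memory use past the budget.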