diff --git a/router/src/queue.rs b/router/src/queue.rs index 48e483a1..60831c3b 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -187,10 +187,17 @@ impl State { max_input_length = max_input_length.max(entry.request.input_length); prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length } else { - prefill_tokens += entry.request.input_length; + // pad to block size + prefill_tokens += ((entry.request.input_length + 16 - 1) / 16) * 16; } - decode_tokens += entry.request.stopping_parameters.max_new_tokens; + if self.requires_padding { + decode_tokens += entry.request.stopping_parameters.max_new_tokens; + } else { + // pad to block size + decode_tokens += + ((entry.request.stopping_parameters.max_new_tokens + 16 - 1) / 16) * 16; + } if prefill_tokens > prefill_token_budget || (prefill_tokens + decode_tokens) > token_budget