diff --git a/router/src/main.rs b/router/src/main.rs index 2f793040..03e920c6 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -220,7 +220,8 @@ fn main() -> Result<(), RouterError> { .map_err(RouterError::Warmup)? { // Older models do not support automatic max-batch-total-tokens - None => max_batch_total_tokens.unwrap_or(16000), + None => max_batch_total_tokens + .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))), // Flash attention models return their max supported total tokens Some(max_supported_batch_total_tokens) => { // Warn if user added his own max-batch-total-tokens as we will ignore it