diff --git a/router/src/infer.rs b/router/src/infer.rs index 60941fa6..09e969a9 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -240,7 +240,6 @@ impl Infer { /// Batches requests and sends them to the inference server async fn batching_task( mut client: ShardedClient, - // max_batch_size: usize, max_waiting_tokens: usize, queue: Queue, shared: Arc, diff --git a/router/src/queue.rs b/router/src/queue.rs index 098a337a..4084eb6c 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -108,8 +108,11 @@ async fn queue_task( #[derive(Debug)] pub(crate) struct BatchingConfig { + /// Upper bound on number of requests in a batch pub(crate) size_limit: usize, + /// Maximum batch "weight" at any point of time (takes sequence lengths into account) pub(crate) weight_limit: usize, + /// Maximum weight of individual prefill batches pub(crate) prefill_weight_limit: usize, } @@ -214,6 +217,8 @@ impl BatchType for FlashBatch { for (bs, (ol, il, _)) in tree.iter().rev().enumerate() { let this_ol = *ol; in_sum += *il; + // Only need to check segments with output_len > current_output_len + // will have been checked in a prior iteration if this_ol <= current_output_len { // Check if we breach max space for this segment let token_count = in_sum + (bs + 1) * this_ol; @@ -244,7 +249,7 @@ impl BatchType for PaddedBatch { fn batch_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize { let (max_input_length, max_output_length) = max_in_out_lengths; let max_seq_len = max_input_length + max_output_length; - // Memory requirement roughly propotionall to batch_size * seq_len^2 + // Memory requirement roughly proportional to batch_size * seq_len^2 batch_size * max_seq_len.pow(2) } diff --git a/router/src/server.rs b/router/src/server.rs index 817dcbba..5ceaa930 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -635,7 +635,7 @@ async fn do_run( validation_workers: usize, addr: SocketAddr, allow_origin: Option, - _batch_type: B, + batch_type: B, ) { // OpenAPI documentation #[derive(OpenApi)] @@ -701,7 +701,7 @@ async fn do_run( max_prefill_weight, max_waiting_tokens, max_concurrent_requests, - FlashBatch{} + batch_type, ); // Duration buckets