diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md
index 712b4fc43..ba54f0586 100644
--- a/docs/source/basic_tutorials/launcher.md
+++ b/docs/source/basic_tutorials/launcher.md
@@ -197,6 +197,14 @@ Options:
           [env: MAX_WAITING_TOKENS=]
           [default: 20]

+```
+## MAX_BATCH_SIZE
+```shell
+      --max-batch-size
+          Enforce a maximum number of requests per batch. Specific flag for hardware targets that do not support unpadded inference
+
+          [env: MAX_BATCH_SIZE=]
+
 ```
 ## HOSTNAME
 ```shell
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 53a40ea8f..428b00c12 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -281,6 +281,11 @@ struct Args {
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,

+    /// Enforce a maximum number of requests per batch.
+    /// Specific flag for hardware targets that do not support unpadded inference
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
+
     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
     hostname: String,
@@ -1056,6 +1061,12 @@ fn spawn_webserver(
         router_args.push(max_batch_total_tokens.to_string());
     }

+    // Router optional max batch size
+    if let Some(max_batch_size) = args.max_batch_size {
+        router_args.push("--max-batch-size".to_string());
+        router_args.push(max_batch_size.to_string());
+    }
+
     // Model optional revision
     if let Some(ref revision) = args.revision {
         router_args.push("--revision".to_string());
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index c61a40032..592338fa2 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -109,7 +109,7 @@ impl Client {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
-        max_batch_total_tokens: Option<u32>,
+        max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let warmup_enabled: bool = env::var("WARMUP_ENABLED").ok().map_or(true, |value| value.to_lowercase() == "true");
         if !warmup_enabled {
@@ -142,17 +142,9 @@ impl Client {
             }
         }

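+        // Warmup needs at most two batches to exercise the concatenate path below,
+        // so any larger max_batch_size is clamped by the .min(2) that follows.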
+        // if max_batch_size is None, create two batches
+        let num_batches = max_batch_size.unwrap_or(2).min(2);
         let mut id_counter: u64 = 0;
-        let num_batches = match max_batch_total_tokens {
-            Some(val) => {
-                if val == max_total_tokens {
-                    1
-                } else {
-                    2
-                }
-            }
-            None => 2, // If max_batch_total_tokens is None, create two batches
-        };
         for shape in shapes.iter() {
             // create two batches in order to trigger concatenate operation
             // in case decode bs=1 create one batch
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index 8d81da6a9..e2c800dd2 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -99,13 +99,18 @@ impl ShardedClient {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
-        max_batch_total_tokens: Option<u32>,
+        max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| {
-                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens, max_batch_total_tokens))
+                Box::pin(client.warmup(
+                    max_input_length,
+                    max_prefill_tokens,
+                    max_total_tokens,
+                    max_batch_size,
+                ))
             })
             .collect();
         // Take the minimum value
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 7b6b99106..48369de91 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -63,6 +63,7 @@ impl Infer {
         max_batch_prefill_tokens: u32,
         max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
+        max_batch_size: Option<usize>,
         max_concurrent_requests: usize,
         requires_padding: bool,
         max_input_length: u32,
@@ -92,6 +93,7 @@ impl Infer {
             max_batch_prefill_tokens,
             max_batch_total_tokens,
             max_waiting_tokens,
+            max_batch_size,
             queue.clone(),
             shared.clone(),
             generation_health,
@@ -349,6 +351,7 @@ async fn batching_task(
     max_batch_prefill_tokens: u32,
     max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
+    max_batch_size: Option<usize>,
     queue: Queue,
     shared: Arc<Shared>,
     generation_health: Arc<AtomicBool>,
@@ -362,7 +365,12 @@ async fn batching_task(
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the queue
         while let Some((mut entries, batch, span)) = queue
-            .next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
+            .next_batch(
+                None,
+                max_batch_size,
+                max_batch_prefill_tokens,
+                max_batch_total_tokens,
+            )
             .await
         {
             let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
@@ -390,10 +398,11 @@ async fn batching_task(
                 };
                 let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
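+                // Number of request slots still free in the running batch; None means unbounded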
+                let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);

                 // Try to get a new batch
                 if let Some((mut new_entries, new_batch, span)) = queue
-                    .next_batch(min_size, max_batch_prefill_tokens, token_budget)
+                    .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
                     .await
                 {
                     // Tracking metrics
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 7c44d642e..3ce9eca8b 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -73,6 +73,8 @@ pub struct Info {
     pub max_batch_total_tokens: u32,
     #[schema(example = "20")]
     pub max_waiting_tokens: usize,
+    #[schema(nullable = true, example = "null")]
+    pub max_batch_size: Option<usize>,
     #[schema(example = "2")]
     pub validation_workers: usize,
     /// Router Info
diff --git a/router/src/main.rs b/router/src/main.rs
index 702393aa9..1757e4596 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -48,6 +48,8 @@ struct Args {
     max_batch_total_tokens: Option<u32>,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
     #[clap(default_value = "0.0.0.0", long, env)]
     hostname: String,
     #[clap(default_value = "3000", long, short, env)]
@@ -94,6 +96,7 @@ async fn main() -> Result<(), RouterError> {
         max_batch_prefill_tokens,
         max_batch_total_tokens,
         max_waiting_tokens,
+        max_batch_size,
         hostname,
         port,
         master_shard_uds_path,
@@ -138,6 +141,25 @@ async fn main() -> Result<(), RouterError> {
         }
     }

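+    // Reconcile the two batch limits: derive whichever of max_batch_size and
+    // max_batch_total_tokens is missing from the other via max_total_tokens,
+    // preferring max_batch_size when both are set but disagree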
+    let (max_batch_size, max_batch_total_tokens) = match (max_batch_size, max_batch_total_tokens) {
+        (Some(_max_batch_size), Some(_max_batch_total_tokens)) => {
+            if (_max_batch_total_tokens as usize / max_total_tokens) != _max_batch_size {
+                tracing::warn!("max_batch_size was set to {_max_batch_size} while max_batch_total_tokens was set to {_max_batch_total_tokens}");
+                tracing::warn!("These values do not match, so max_batch_size will be preferred");
+                (Some(_max_batch_size), Some((_max_batch_size * max_total_tokens) as u32))
+            } else {
+                (Some(_max_batch_size), Some(_max_batch_total_tokens))
+            }
+        },
+        (Some(_max_batch_size), None) => (
+            Some(_max_batch_size),
+            Some((_max_batch_size * max_total_tokens) as u32),
+        ),
+        (None, Some(_max_batch_total_tokens)) => (
+            Some(_max_batch_total_tokens as usize / max_total_tokens),
+            Some(_max_batch_total_tokens),
+        ),
+        (None, None) => (None, None),
+    };
+
     // CORS allowed origins
     // map to go inside the option and then map to parse from String to HeaderValue
     // Finally, convert to AllowOrigin
@@ -298,7 +320,7 @@ async fn main() -> Result<(), RouterError> {
             max_input_length as u32,
             max_batch_prefill_tokens,
             max_total_tokens as u32,
-            max_batch_total_tokens,
+            max_batch_size,
         )
         .await
         .map_err(RouterError::Warmup)?
@@ -355,6 +377,7 @@ async fn main() -> Result<(), RouterError> {
         max_batch_prefill_tokens,
         max_supported_batch_total_tokens,
         max_waiting_tokens,
+        max_batch_size,
         sharded_client,
         tokenizer,
         validation_workers,
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 9e3494f7b..00021812f 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -79,6 +79,7 @@ impl Queue {
     pub(crate) async fn next_batch(
         &self,
         min_size: Option<usize>,
+        max_size: Option<usize>,
         prefill_token_budget: u32,
         token_budget: u32,
     ) -> Option<NextBatch> {
@@ -89,6 +90,7 @@ impl Queue {
         self.queue_sender
             .send(QueueCommand::NextBatch {
                 min_size,
+                max_size,
                 prefill_token_budget,
                 token_budget,
                 response_sender,
@@ -128,12 +130,14 @@ async fn queue_task(
             }
             QueueCommand::NextBatch {
                 min_size,
+                max_size,
                 prefill_token_budget,
                 token_budget,
                 response_sender,
                 span,
             } => span.in_scope(|| {
-                let next_batch = state.next_batch(min_size, prefill_token_budget, token_budget);
+                let next_batch =
+                    state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
                 response_sender.send(next_batch).unwrap();
                 metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
             }),
@@ -308,6 +312,7 @@ impl State {
     fn next_batch(
         &mut self,
         min_size: Option<usize>,
+        max_size: Option<usize>,
         prefill_token_budget: u32,
         token_budget: u32,
     ) -> Option<NextBatch> {
@@ -403,6 +408,11 @@ impl State {
             entry.batch_time = Some(Instant::now());
             // Insert in batch_entries IntMap
             batch_entries.insert(id, entry);
+
+            // Check if max_size has been reached
+            if Some(batch_requests.len()) == max_size {
+                break;
+            }
         }

         // Empty batch
@@ -451,6 +461,7 @@ enum QueueCommand {
     Append(Box<Entry>, Span),
     NextBatch {
         min_size: Option<usize>,
+        max_size: Option<usize>,
         prefill_token_budget: u32,
         token_budget: u32,
         response_sender: oneshot::Sender<Option<NextBatch>>,
@@ -535,8 +546,8 @@ mod tests {
     fn test_next_batch_empty() {
         let mut state = default_state();

-        assert!(state.next_batch(None, 1, 1).is_none());
-        assert!(state.next_batch(Some(1), 1, 1).is_none());
+        assert!(state.next_batch(None, None, 1, 1).is_none());
+        assert!(state.next_batch(Some(1), None, 1, 1).is_none());
     }

     #[test]
@@ -547,7 +558,7 @@ mod tests {
         state.append(entry1);
         state.append(entry2);

-        let (entries, batch, _) = state.next_batch(None, 2, 4).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 2, 4).unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -563,7 +574,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         state.append(entry3);

-        assert!(state.next_batch(Some(2), 2, 2).is_none());
+        assert!(state.next_batch(Some(2), None, 2, 2).is_none());

         assert_eq!(state.next_id, 3);
         assert_eq!(state.entries.len(), 1);
         let (id, _) = state.entries.remove(0).unwrap();
         assert_eq!(id, 2);
     }

+    #[test]
+    fn test_next_batch_max_size() {
+        let mut state = default_state();
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        state.append(entry1);
+        state.append(entry2);
+
+        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
+        assert_eq!(entries.len(), 1);
+        assert!(entries.contains_key(&0));
+        assert!(entries.get(&0).unwrap().batch_time.is_some());
+        assert_eq!(batch.id, 0);
+        assert_eq!(batch.size, 1);
+
+        assert_eq!(state.next_id, 2);
+        assert_eq!(state.entries.len(), 1);
+        assert_eq!(state.next_batch_id, 1);
+    }
+
     #[test]
     fn test_next_batch_token_budget() {
         let mut state = default_state();
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);

-        let (entries, batch, _) = state.next_batch(None, 1, 2).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 1, 2).unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert_eq!(batch.id, 0);
         assert_eq!(batch.size, 1);

@@ -592,7 +623,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         state.append(entry3);

-        let (entries, batch, _) = state.next_batch(None, 3, 6).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 3, 6).unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&1));
         assert!(entries.contains_key(&2));
@@ -615,8 +646,8 @@ mod tests {
     async fn test_queue_next_batch_empty() {
         let queue = default_queue();

-        assert!(queue.next_batch(None, 1, 1).await.is_none());
-        assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
+        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
+        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
     }

     #[tokio::test]
@@ -627,7 +658,7 @@ mod tests {
         queue.append(entry1);
         queue.append(entry2);

-        let (entries, batch, _) = queue.next_batch(None, 2, 4).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 2, 4).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -640,11 +671,11 @@ mod tests {
         queue.append(entry3);

         // Not enough requests pending
-        assert!(queue.next_batch(Some(2), 2, 2).await.is_none());
+        assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
         // Not enough token budget
-        assert!(queue.next_batch(Some(1), 0, 0).await.is_none());
+        assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
         // Ok
-        let (entries2, batch2, _) = queue.next_batch(Some(1), 1, 2).await.unwrap();
+        let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 4).await.unwrap();
         assert_eq!(entries2.len(), 1);
         assert!(entries2.contains_key(&2));
         assert!(entries2.get(&2).unwrap().batch_time.is_some());
         assert_eq!(batch2.id, 1);
         assert_eq!(batch2.size, 1);
     }

+    #[tokio::test]
+    async fn test_queue_next_batch_max_size() {
+        let queue = default_queue();
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        queue.append(entry1);
+        queue.append(entry2);
+
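+        // Request a batch capped at a single entry via max_size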
+        let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
+        assert_eq!(entries.len(), 1);
+        assert!(entries.contains_key(&0));
+        assert!(entries.get(&0).unwrap().batch_time.is_some());
+        assert_eq!(batch.id, 0);
+        assert_eq!(batch.size, 1);
+    }
+
     #[tokio::test]
     async fn test_queue_next_batch_token_budget() {
         let queue = default_queue();
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
         queue.append(entry2);

-        let (entries, batch, _) = queue.next_batch(None, 1, 2).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 1, 2).await.unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert_eq!(batch.id, 0);
         assert_eq!(batch.size, 1);

         let (entry3, _guard3) = default_entry();
         queue.append(entry3);

-        let (entries, batch, _) = queue.next_batch(None, 2, 4).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 3, 6).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&1));
         assert!(entries.contains_key(&2));
@@ -686,9 +733,9 @@ mod tests {
         queue.append(entry2);

         // Budget of 1 is not enough
-        assert!(queue.next_batch(None, 1, 1).await.is_none());
+        assert!(queue.next_batch(None, None, 1, 1).await.is_none());

-        let (entries, batch, _) = queue.next_batch(None, 6, 6).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -702,6 +749,6 @@ mod tests {
         let (entry, _) = default_entry();
         queue.append(entry);

-        assert!(queue.next_batch(None, 1, 1).await.is_none());
+        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
     }
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index 15ad6b336..450494df6 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -770,6 +770,7 @@ pub async fn run(
     max_batch_prefill_tokens: u32,
     max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
+    max_batch_size: Option<usize>,
     client: ShardedClient,
     tokenizer: Option<Tokenizer>,
     validation_workers: usize,
@@ -851,6 +852,7 @@ pub async fn run(
         max_batch_prefill_tokens,
         max_batch_total_tokens,
         max_waiting_tokens,
+        max_batch_size,
         max_concurrent_requests,
         shard_info.requires_padding,
         max_input_length as u32,
@@ -934,6 +936,7 @@ pub async fn run(
         waiting_served_ratio,
         max_batch_total_tokens,
         max_waiting_tokens,
+        max_batch_size,
         validation_workers,
         version: env!("CARGO_PKG_VERSION"),
         sha: option_env!("VERGEN_GIT_SHA"),
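A quick usage sketch of the new flag, assuming this patch is applied; the model id below is a placeholder, not a value fixed by the patch:

```shell
# Cap each batch at 4 requests; useful on hardware targets without unpadded inference.
text-generation-launcher --model-id <model-id> --max-batch-size 4

# The same cap can be set through the environment variable read by the launcher and router.
MAX_BATCH_SIZE=4 text-generation-launcher --model-id <model-id>
```

When only one of the two batch limits is given, the router derives the other: max_batch_total_tokens becomes max_batch_size * max_total_tokens (and vice versa), with max_batch_size taking precedence if both are set but disagree.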