From faaa9dfe0aaa8441a1cb556e713bbfc8e86eeb74 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 8 Feb 2024 17:01:20 +0100 Subject: [PATCH] feat(router): add max_batch_size --- Cargo.lock | 38 ++++++++++++++- launcher/src/main.rs | 11 +++++ router/client/src/client.rs | 6 +++ router/client/src/sharded_client.rs | 8 +++- router/src/infer.rs | 3 +- router/src/lib.rs | 2 + router/src/main.rs | 5 ++ router/src/queue.rs | 72 +++++++++++++++++++++++------ router/src/server.rs | 3 ++ 9 files changed, 131 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7fdf301a..3318f3b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2787,7 +2787,7 @@ dependencies = [ "tabled", "text-generation-client", "thiserror", - "tokenizers", + "tokenizers 0.14.1", "tokio", "tracing", "tracing-subscriber", @@ -2850,7 +2850,7 @@ dependencies = [ "serde_json", "text-generation-client", "thiserror", - "tokenizers", + "tokenizers 0.15.1", "tokio", "tokio-stream", "tower-http", @@ -2972,6 +2972,40 @@ dependencies = [ "unicode_categories", ] +[[package]] +name = "tokenizers" +version = "0.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db445cceba5dfeb0f9702be7d6bfd91801ddcbe8fe8722defe7f2e96da75812" +dependencies = [ + "aho-corasick", + "clap", + "derive_builder", + "esaxx-rs", + "getrandom", + "hf-hub", + "indicatif", + "itertools 0.11.0", + "lazy_static", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand", + "rayon", + "rayon-cond", + "regex", + "regex-syntax 0.7.5", + "serde", + "serde_json", + "spm_precompiled", + "thiserror", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + [[package]] name = "tokio" version = "1.35.1" diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 054e546c..a51742e6 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -279,6 +279,11 @@ struct Args { #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, + /// Enforce a maximum number of requests per batch + /// Specific flag for hardware targets that do not support unpadded inference + #[clap(long, env)] + max_batch_size: Option, + /// The IP address to listen on #[clap(default_value = "0.0.0.0", long, env)] hostname: String, @@ -1046,6 +1051,12 @@ fn spawn_webserver( router_args.push(max_batch_total_tokens.to_string()); } + // Router optional max batch size + if let Some(max_batch_size) = args.max_batch_size { + router_args.push("--max-batch-size".to_string()); + router_args.push(max_batch_size.to_string()); + } + // Model optional revision if let Some(ref revision) = args.revision { router_args.push("--revision".to_string()); diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 023c5671..b5fa042e 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -105,6 +105,7 @@ impl Client { max_input_length: u32, max_prefill_tokens: u32, max_total_tokens: u32, + max_batch_size: Option, ) -> Result> { let mut n_tokens = 0; let mut requests = Vec::new(); @@ -136,6 +137,11 @@ impl Client { top_n_tokens: 20, }); n_tokens += max_input_length; + + // Check max_batch_size + if Some(requests.len()) == max_batch_size { + break; + } } let batch = Batch { diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index f0e65ce5..e1e52d59 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -97,12 +97,18 @@ impl ShardedClient { max_input_length: u32, max_prefill_tokens: u32, max_total_tokens: u32, + max_batch_size: Option, ) -> Result> { let futures: Vec<_> = self .clients .iter_mut() .map(|client| { - Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens)) + Box::pin(client.warmup( + max_input_length, + max_prefill_tokens, + max_total_tokens, + max_batch_size, + )) }) .collect(); // Take the minimum value diff --git a/router/src/infer.rs b/router/src/infer.rs index 4da0da0a..f4441604 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -61,6 +61,7 @@ impl Infer { max_batch_prefill_tokens: u32, max_batch_total_tokens: u32, max_waiting_tokens: usize, + max_batch_size: Option, max_concurrent_requests: usize, requires_padding: bool, window_size: Option, @@ -69,7 +70,7 @@ impl Infer { tokenizer_config: HubTokenizerConfig, ) -> Self { // Infer shared state - let queue = Queue::new(requires_padding, 16, window_size, speculate); + let queue = Queue::new(requires_padding, max_batch_size, 16, window_size, speculate); let shared = Arc::new(Shared { batching_task: Notify::new(), }); diff --git a/router/src/lib.rs b/router/src/lib.rs index e85519cc..c88824ce 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -73,6 +73,8 @@ pub struct Info { pub max_batch_total_tokens: u32, #[schema(example = "20")] pub max_waiting_tokens: usize, + #[schema(nullable = true, example = "null")] + pub max_batch_size: Option, #[schema(example = "2")] pub validation_workers: usize, /// Router Info diff --git a/router/src/main.rs b/router/src/main.rs index 2a080468..a1f8d97b 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -45,6 +45,8 @@ struct Args { max_batch_total_tokens: Option, #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, + #[clap(long, env)] + max_batch_size: Option, #[clap(default_value = "0.0.0.0", long, env)] hostname: String, #[clap(default_value = "3000", long, short, env)] @@ -91,6 +93,7 @@ async fn main() -> Result<(), RouterError> { max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, + max_batch_size, hostname, port, master_shard_uds_path, @@ -288,6 +291,7 @@ async fn main() -> Result<(), RouterError> { max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32, + max_batch_size, ) .await .map_err(RouterError::Warmup)? @@ -344,6 +348,7 @@ async fn main() -> Result<(), RouterError> { max_batch_prefill_tokens, max_supported_batch_total_tokens, max_waiting_tokens, + max_batch_size, sharded_client, tokenizer, validation_workers, diff --git a/router/src/queue.rs b/router/src/queue.rs index 106cacc4..8d855049 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -36,6 +36,7 @@ pub(crate) struct Queue { impl Queue { pub(crate) fn new( requires_padding: bool, + max_batch_size: Option, block_size: u32, window_size: Option, speculate: u32, @@ -46,6 +47,7 @@ impl Queue { // Launch background queue task tokio::spawn(queue_task( requires_padding, + max_batch_size, block_size, window_size, speculate, @@ -95,6 +97,7 @@ impl Queue { // Background task responsible of the queue state async fn queue_task( requires_padding: bool, + max_size: Option, block_size: u32, window_size: Option, speculate: u32, @@ -115,7 +118,8 @@ async fn queue_task( response_sender, span, } => span.in_scope(|| { - let next_batch = state.next_batch(min_size, prefill_token_budget, token_budget); + let next_batch = + state.next_batch(min_size, max_size, prefill_token_budget, token_budget); response_sender.send(next_batch).unwrap(); metrics::gauge!("tgi_queue_size", state.entries.len() as f64); }), @@ -181,6 +185,7 @@ impl State { fn next_batch( &mut self, min_size: Option, + max_size: Option, prefill_token_budget: u32, token_budget: u32, ) -> Option { @@ -274,6 +279,11 @@ impl State { entry.batch_time = Some(Instant::now()); // Insert in batch_entries IntMap batch_entries.insert(id, entry); + + // Check if max_size + if Some(batch_requests.len()) == max_size { + break; + } } // Empty batch @@ -393,8 +403,8 @@ mod tests { fn test_next_batch_empty() { let mut state = State::new(false, 1, None, 0); - assert!(state.next_batch(None, 1, 1).is_none()); - assert!(state.next_batch(Some(1), 1, 1).is_none()); + assert!(state.next_batch(None, None, 1, 1).is_none()); + assert!(state.next_batch(Some(1), None, 1, 1).is_none()); } #[test] @@ -405,7 +415,7 @@ mod tests { state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, 2, 2).unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&0)); assert!(entries.contains_key(&1)); @@ -421,7 +431,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - assert!(state.next_batch(Some(2), 2, 2).is_none()); + assert!(state.next_batch(Some(2), None, 2, 2).is_none()); assert_eq!(state.next_id, 3); assert_eq!(state.entries.len(), 1); @@ -429,6 +439,26 @@ mod tests { assert_eq!(id, 2); } + #[test] + fn test_next_batch_max_size() { + let mut state = State::new(false, 1, None, 0); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + state.append(entry1); + state.append(entry2); + + let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap(); + assert_eq!(entries.len(), 1); + assert!(entries.contains_key(&0)); + assert!(entries.get(&0).unwrap().batch_time.is_some()); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 1); + + assert_eq!(state.next_id, 2); + assert_eq!(state.entries.len(), 1); + assert_eq!(state.next_batch_id, 1); + } + #[test] fn test_next_batch_token_budget() { let mut state = State::new(false, 1, None, 0); @@ -437,7 +467,7 @@ mod tests { state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, 1, 1).unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert_eq!(batch.id, 0); @@ -450,7 +480,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - let (entries, batch, _) = state.next_batch(None, 3, 3).unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&1)); assert!(entries.contains_key(&2)); @@ -464,14 +494,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false, 1, None, 0); + let queue = Queue::new(false, None, 1, None, 0); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false, 1, None, 0); + let queue = Queue::new(false, None, 1, None, 0); assert!(queue.next_batch(None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), 1, 1).await.is_none()); @@ -479,7 +509,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false, 1, None, 0); + let queue = Queue::new(false, None, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -510,9 +540,25 @@ mod tests { assert_eq!(batch2.size, 1); } + #[tokio::test] + async fn test_queue_next_batch_max_size() { + let queue = Queue::new(false, Some(1), 1, None, 0); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + queue.append(entry1); + queue.append(entry2); + + let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap(); + assert_eq!(entries.len(), 1); + assert!(entries.contains_key(&0)); + assert!(entries.get(&0).unwrap().batch_time.is_some()); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 1); + } + #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false, 1, None, 0); + let queue = Queue::new(false, None, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -537,7 +583,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(false, 1, None, 2); + let queue = Queue::new(false, None, 1, None, 2); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -556,7 +602,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false, 1, None, 0); + let queue = Queue::new(false, None, 1, None, 0); let (entry, _) = default_entry(); queue.append(entry); diff --git a/router/src/server.rs b/router/src/server.rs index b4d26158..aa9b0c1a 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -762,6 +762,7 @@ pub async fn run( max_batch_prefill_tokens: u32, max_batch_total_tokens: u32, max_waiting_tokens: usize, + max_batch_size: Option, client: ShardedClient, tokenizer: Option, validation_workers: usize, @@ -843,6 +844,7 @@ pub async fn run( max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, + max_batch_size, max_concurrent_requests, shard_info.requires_padding, shard_info.window_size, @@ -924,6 +926,7 @@ pub async fn run( waiting_served_ratio, max_batch_total_tokens, max_waiting_tokens, + max_batch_size, validation_workers, version: env!("CARGO_PKG_VERSION"), sha: option_env!("VERGEN_GIT_SHA"),