From ba1aae3e78356cd082dff70aee45d8981819cd2b Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Wed, 19 Apr 2023 07:58:40 -0700
Subject: [PATCH] feat(router): dynamic batch sizing

---
 launcher/src/main.rs     |  16 ++
 router/src/infer.rs      |  78 +++---
 router/src/main.rs       |   8 +
 router/src/queue.rs      | 517 +++++++++++++++++++--------------------
 router/src/server.rs     |  13 +
 router/src/validation.rs |  11 +-
 6 files changed, 341 insertions(+), 302 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index fcac736d..1034ca00 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -41,6 +41,10 @@ struct Args {
     max_total_tokens: usize,
     #[clap(default_value = "32", long, env)]
     max_batch_size: usize,
+    #[clap(default_value = None, long, env)]
+    max_batch_weight: Option<usize>,
+    #[clap(default_value = None, long, env)]
+    max_prefill_weight: Option<usize>,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
@@ -93,6 +97,8 @@ fn main() -> ExitCode {
         max_input_length,
         max_total_tokens,
         max_batch_size,
+        max_batch_weight,
+        max_prefill_weight,
         max_waiting_tokens,
         port,
         shard_uds_path,
@@ -392,6 +398,16 @@ fn main() -> ExitCode {
         model_id,
     ];
 
+    if let Some(max_batch_weight) = max_batch_weight {
+        argv.push("--max-batch-weight".to_string());
+        argv.push(max_batch_weight.to_string())
+    }
+
+    if let Some(max_prefill_weight) = max_prefill_weight {
+        argv.push("--max-prefill-weight".to_string());
+        argv.push(max_prefill_weight.to_string())
+    }
+
     // Model optional revision
     if let Some(ref revision) = revision {
         argv.push("--revision".to_string());
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 484720a0..21c34d0b 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -15,6 +15,7 @@ use thiserror::Error;
 use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Instrument, Span};
+use crate::queue::BatchingConfig;
 
 /// Inference struct
 #[derive(Clone)]
@@ -40,11 +41,17 @@ impl Infer {
         client: ShardedClient,
         validation: Validation,
         max_batch_size: usize,
+        max_batch_weight: usize,
+        max_prefill_weight: usize,
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new();
+        let queue = Queue::new(BatchingConfig {
+            size_limit: max_batch_size,
+            weight_limit: max_batch_weight,
+            prefill_weight_limit: max_prefill_weight,
+        });
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
@@ -52,7 +59,6 @@ impl Infer {
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
             client,
-            max_batch_size,
             max_waiting_tokens,
             queue.clone(),
             shared.clone(),
@@ -105,6 +111,7 @@ impl Infer {
         // Append the request to the queue
         self.queue.append(Entry {
             request: valid_request,
+            generated_tokens: 0,
             response_tx,
             span: Span::current(),
             temp_span: None,
@@ -232,18 +239,11 @@ impl Infer {
 /// Batches requests and sends them to the inference server
 async fn batching_task(
     mut client: ShardedClient,
-    max_batch_size: usize,
+    // max_batch_size: usize,
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
 ) {
-    // Minimum batch size after which we try to add more requests
-    let limit_min_batch_size = if max_batch_size > 1 {
-        (max_batch_size / 2) as u32
-    } else {
-        0
-    };
-
     // Infinite loop
     loop {
         // Wait for a notification from the Infer struct
@@ -252,8 +252,8 @@ async fn batching_task(
         // Get the next batch from the queue
         // This batch might be smaller than the
maximum batch size if there are not enough requests // waiting in the queue - while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await { - let mut cached_batch = prefill(&mut client, batch, &mut entries) + while let (_, Some((mut entries, batch, span))) = queue.next_batch(None).await { + let (mut cached_batch, mut some_completed) = prefill(&mut client, batch, &mut entries) .instrument(span) .await; let mut waiting_tokens = 1; @@ -266,21 +266,16 @@ async fn batching_task( let mut batches = vec![batch]; metrics::gauge!("tgi_batch_current_size", batch_size as f64); - // If the current batch is too small, we try to add more requests to it - if batch_size <= limit_min_batch_size { - let min_size = match waiting_tokens { - // If we didn't onboard any new requests since >= max_waiting_tokens, we try - // to add a new batch even though its size might be small - _ if waiting_tokens >= max_waiting_tokens => None, - // Minimum size criteria - _ => Some(limit_min_batch_size as usize), - }; + // Try to extend batch if its size reduced or enough tokens have elapsed since last one + if some_completed || waiting_tokens >= max_waiting_tokens { - // Try to get a new batch - if let Some((mut new_entries, new_batch, span)) = queue - .next_batch(min_size, max_batch_size - batch_size as usize) - .await - { + // Try to get a new batch - ownership of entries passed in and out + let ( + existing_entries, new_entries + ) = queue.next_batch(Some(entries)).await; + entries = existing_entries.unwrap(); + + if let Some((mut new_entries, new_batch, span)) = new_entries { entries.iter_mut().for_each(|(_, entry)| { // Create a new span to add the info that this entry is waiting // because a new batch is being computed @@ -293,7 +288,7 @@ async fn batching_task( }); // Generate one token for this new batch to have the attention past in cache - let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries) + let (new_cached_batch, _) = prefill(&mut client, new_batch, &mut new_entries) .instrument(span) .await; // Reset waiting counter @@ -319,7 +314,7 @@ async fn batching_task( entry.temp_span = Some(entry_batch_span); }); - cached_batch = decode(&mut client, batches, &mut entries) + (cached_batch, some_completed) = decode(&mut client, batches, &mut entries) .instrument(next_batch_span) .await; waiting_tokens += 1; @@ -334,14 +329,14 @@ async fn prefill( client: &mut ShardedClient, batch: Batch, entries: &mut IntMap, -) -> Option { +) -> (Option, bool) { let start_time = Instant::now(); let batch_id = batch.id; metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill"); match client.prefill(batch).await { Ok((generations, next_batch)) => { - filter_send_generations(generations, entries); + let some_completed = filter_send_generations(generations, entries); // Filter next batch and remove requests that were stopped let next_batch = match next_batch { @@ -360,14 +355,14 @@ async fn prefill( metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill"); - next_batch + (next_batch, some_completed) } // If we have an error, we discard the whole batch Err(err) => { let _ = client.clear_cache(Some(batch_id)).await; send_errors(err, entries); metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill"); - None + (None, true) } } } @@ -377,14 +372,14 @@ async fn decode( client: &mut ShardedClient, batches: Vec, entries: 
&mut IntMap, -) -> Option { +) -> (Option, bool) { let start_time = Instant::now(); let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode"); match client.decode(batches).await { Ok((generations, next_batch)) => { - filter_send_generations(generations, entries); + let some_completed = filter_send_generations(generations, entries); // Filter next batch and remove requests that were stopped let next_batch = match next_batch { @@ -403,7 +398,7 @@ async fn decode( metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode"); metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode"); - next_batch + (next_batch, some_completed) } // If we have an error, we discard the whole batch Err(err) => { @@ -412,7 +407,7 @@ async fn decode( } send_errors(err, entries); metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode"); - None + (None, true) } } } @@ -431,14 +426,16 @@ fn filter_batch(mut batch: Batch, entries: &IntMap) -> Option /// Send one or multiple `InferStreamResponse` to Infer for all `entries` /// and filter entries +/// Return true if any requests completed #[instrument(skip_all)] -fn filter_send_generations(generations: Vec, entries: &mut IntMap) { +fn filter_send_generations(generations: Vec, entries: &mut IntMap) -> bool { + let mut some_stopped = false; generations.into_iter().for_each(|generation| { let id = generation.request_id; // Get entry // We can `expect` here as the request id should always be in the entries let entry = entries - .get(&id) + .get_mut(&id) .expect("ID not found in entries. This is a bug."); // Create and enter a span to link this function back to the entry @@ -451,9 +448,14 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap, + #[clap(default_value = None, long, env)] + max_prefill_weight: Option, #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, #[clap(default_value = "3000", long, short, env)] @@ -64,6 +68,8 @@ fn main() -> Result<(), std::io::Error> { max_input_length, max_total_tokens, max_batch_size, + max_batch_weight, + max_prefill_weight, max_waiting_tokens, port, master_shard_uds_path, @@ -169,6 +175,8 @@ fn main() -> Result<(), std::io::Error> { max_input_length, max_total_tokens, max_batch_size, + max_batch_weight, + max_prefill_weight, max_waiting_tokens, sharded_client, tokenizer, diff --git a/router/src/queue.rs b/router/src/queue.rs index 43651ff3..863025b3 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -2,8 +2,9 @@ use crate::infer::InferError; use crate::infer::InferStreamResponse; use crate::validation::ValidGenerateRequest; use nohash_hasher::{BuildNoHashHasher, IntMap}; -use std::cmp::min; -use std::collections::VecDeque; +use std::collections::{BTreeSet, VecDeque}; +use std::ops::Add; +use std::time::Duration; use text_generation_client::{Batch, Request}; use tokio::sync::oneshot; use tokio::time::Instant; @@ -14,6 +15,8 @@ use tracing::{info_span, instrument, Span}; pub(crate) struct Entry { /// Request pub request: ValidGenerateRequest, + /// Count of tokens generated so far + pub generated_tokens: usize, /// Response sender to communicate between the Infer struct and the batching_task pub response_tx: flume::Sender>, /// Span that will live as long as entry @@ -34,12 +37,12 @@ pub(crate) struct Queue { } impl Queue { - pub(crate) fn new() -> Self { + pub(crate) fn new(config: BatchingConfig) -> Self { // Create 
channel let (queue_sender, queue_receiver) = flume::unbounded(); // Launch background queue task - tokio::spawn(queue_task(queue_receiver)); + tokio::spawn(queue_task(queue_receiver, config)); Self { queue_sender } } @@ -54,21 +57,19 @@ impl Queue { .unwrap(); } - // Get the next batch + // Get the next batch - existing batch is returned unchanged #[instrument(skip(self))] pub(crate) async fn next_batch( &self, - min_size: Option, - max_size: usize, - ) -> Option { + entries: Option, + ) -> (Option, Option) { // Create response channel let (response_sender, response_receiver) = oneshot::channel(); // Send next batch command to the background task managing the state // Unwrap is safe here self.queue_sender .send(QueueCommand::NextBatch { - min_size, - max_size, + entries, response_sender, span: Span::current(), }) @@ -80,28 +81,37 @@ impl Queue { } // Background task responsible of the queue state -async fn queue_task(receiver: flume::Receiver) { - let mut state = State::new(); +async fn queue_task(receiver: flume::Receiver, config: BatchingConfig) { + let mut state = State::new(config); while let Ok(cmd) = receiver.recv_async().await { match cmd { QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)), QueueCommand::NextBatch { - min_size, - max_size, + entries, response_sender, span, } => span.in_scope(|| { - let next_batch = state.next_batch(min_size, max_size); - response_sender.send(next_batch).unwrap_or(()); + let response = state.next_batch(entries); + response_sender.send(response).unwrap_or(()); }), } } } +#[derive(Debug)] +pub(crate) struct BatchingConfig { + pub(crate) size_limit: usize, + pub(crate) weight_limit: usize, + pub(crate) prefill_weight_limit: usize, +} + /// Queue State #[derive(Debug)] struct State { + /// Batching configuration + config: BatchingConfig, + /// Queue entries organized in a Vec entries: VecDeque<(u64, Entry)>, @@ -110,14 +120,46 @@ struct State { /// Id of the next batch next_batch_id: u64, + + // Remembered size of the last batch, used to determine + // when entries have completed between calls to the + // next_batch function + last_seen_batch_size: usize, + + // Index in the queue up to which entries have been + // checked to see if they can fit into the current batch. 
+ // Reset to zero when any existing entries complete + checked_request_count: usize, + + /// true if it's known that the current size of the + /// requests in the queue is too small to prefill an add-on batch + buffer_contents_insufficient: bool, + + /// Just a constant empty map to reuse + empty_map: ExistingBatch, } +// Could also make these configurable + +/// Longest that requests can be waiting before we ignore the minimum +/// size requirement when adding to a new batch +const MAX_WAITING_DURATION: Duration = Duration::from_secs(1); + +/// Maximum difference in arrival time that smaller requests can jump +/// ahead of larger ones in the queue +const CUTOFF_DURATION: Duration = Duration::from_secs(1); + impl State { - fn new() -> Self { + fn new(config: BatchingConfig) -> Self { Self { + config, entries: VecDeque::with_capacity(128), next_id: 0, next_batch_id: 0, + last_seen_batch_size: 0, + checked_request_count: 0, + buffer_contents_insufficient: false, + empty_map: IntMap::default(), } } @@ -134,70 +176,221 @@ impl State { } // Get the next batch - fn next_batch(&mut self, min_size: Option, max_size: usize) -> Option { - if self.entries.is_empty() { - return None; + fn next_batch( + &mut self, existing_entries_opt: Option, + ) -> (Option, Option) { + + // Use ref to empty map in None case to simplify subsequent logic + let existing_entries = existing_entries_opt.as_ref().unwrap_or(&self.empty_map); + + let config = &self.config; + let mut total_count = existing_entries.len(); + if total_count >= config.size_limit { + // We are already at max batch size + return (existing_entries_opt, None) } - // Check if we have enough entries - if let Some(min_size) = min_size { - if self.entries.len() < min_size { - return None; + if total_count != self.last_seen_batch_size { + // Reset the count of checked requests if any completed since last check + self.checked_request_count = 0; + self.last_seen_batch_size = total_count + } + + // Filter entries where the response receiver was dropped (== entries where the request + // was dropped by the client) + let queue_len_before = self.entries.len(); + self.entries.retain_mut(|(_, entry)| !entry.response_tx.is_disconnected()); + if queue_len_before != self.entries.len() { + // Reset the count of checked requests if any in the queue were cancelled since last check + self.checked_request_count = 0; + } + + // This will generally be zero, but if no requests have been completed + // since last time, we don't need to reconsider those already checked + let mut checked_up_to_index = self.checked_request_count; + + if !existing_entries.is_empty() { + // If we don't have any new requests in the buffer to check + if self.entries.len() <= checked_up_to_index || + // Or the current buffer isn't large enough to satisfy the min prefill requirement + self.buffer_contents_insufficient && !self.next_entry_waiting_too_long() { + return (existing_entries_opt, None) } } - let max_batch_size = min(self.entries.len(), max_size); + // Indices into buffer of entries chosen to add to next batch or remove due to expiry + let mut chosen_indices = vec![]; + // Indices to drop due to client cancellation + let mut indices_to_drop = vec![]; + let mut btree = None; + let mut time_cutoff = None; + let mut hit_prefill_weight_limit = false; + + let mut total_token_count = existing_entries.iter().map( + |(_, e)| e.request.stopping_parameters.max_new_tokens + e.request.truncate + ).sum::() as usize; + + let mut prefill_size = 0; + // We first do a read-only pass over the queue to allow 
skipping over large entries + // that don't fit in the current batch to reach smaller entries that do + let mut queue_index = checked_up_to_index; + 'queue_loop: for (entry_id, entry) in self.entries.range(queue_index..) { + if matches!(time_cutoff, Some(t) if entry.queue_time > t) { + break + } + queue_index += 1; + if entry.response_tx.is_disconnected() { + // Eject cancelled entry from queue + indices_to_drop.push(queue_index); + continue + } + // This is the index into the queue after cancelled entries + // have been pruned + checked_up_to_index += 1; + + let input_len = entry.request.truncate as usize; + let output_len = entry.request.stopping_parameters.max_new_tokens as usize; + let next_total_token_count = total_token_count + input_len + output_len; + + // Avoid more granular analysis if possible + if next_total_token_count > config.weight_limit { + // We aren't sure whether this next request will fit, so populate + // a btree with the current batch of requests, the set of + // requests already evaluated, and this one, and perform more + // granular analysis to verify that the batch shape won't exceed + // the limit at any point + + // Allocate btree the first time it's required + let tree = btree.get_or_insert_with(|| { + let mut t = Box::new(BTreeSet::new()); + // Populate with records corresponding to all existing and pending entries + let pending = chosen_indices.iter() + .map(|i| self.entries.get(*i).unwrap()) + .map(|(eid, e)| (eid, e)); + for (eid, e) in existing_entries.iter().chain(pending) { + let generated_count = e.generated_tokens; + t.insert(( + e.request.stopping_parameters.max_new_tokens as usize - generated_count, + e.request.truncate as usize + e.generated_tokens, + eid, + )); + } + t + }); + // Add the current entry + tree.insert((output_len, input_len, entry_id)); + + // Perform analysis + let mut in_sum = 0; + // Work backwards from longest projected entry + for (bs, (ol, il, _)) in tree.iter().rev().enumerate() { + let this_ol = *ol; + in_sum += *il; + if this_ol <= output_len { + // Check if we breach max space for this segment + let token_count = in_sum + (bs + 1) * this_ol; + if token_count > config.weight_limit { + // Remove our tuple from the set + tree.remove(&(output_len, input_len, entry_id)); + time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION)); + continue 'queue_loop + } + } + } + } else if let Some(tree) = btree.as_mut() { + // If we initialized the btree for a prior request, keep it updated + tree.insert((output_len, input_len, entry_id)); + } + // Here, we can add this request to the batch without breaching memory limit + + if config.prefill_weight_limit > 0 { + // Also check whether adding this request will make the batch of new requests + // too expensive latency-wise to perform in a single forward-pass. 
+ if prefill_size + input_len > config.prefill_weight_limit { + if let Some(tree) = btree.as_mut() { + // Remove our tuple from the set + tree.remove(&(output_len, input_len, entry_id)); + hit_prefill_weight_limit = true; + } + time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION)); + continue + } + } + + total_token_count = next_total_token_count; + prefill_size += input_len; + + chosen_indices.push(queue_index - 1); + total_count += 1; + if total_count >= config.size_limit { + break + } + } + + // Drop any cancelled requests + if !indices_to_drop.is_empty() { + indices_to_drop.iter().for_each(|i| { + self.entries.remove(*i); + }); + metrics::gauge!("tgi_queue_size", self.entries.len() as f64); + } + + let next_batch_size = chosen_indices.len(); + if next_batch_size == 0 { + // This gets reset to zero when any requests in the existing batch are removed + self.checked_request_count = checked_up_to_index; + return (existing_entries_opt, None) + } + self.checked_request_count = 0; + + if !hit_prefill_weight_limit && !existing_entries.is_empty() { + // If this is to be added to an existing batch, ensure it meets urgency or size + // requirements to avoid too frequent prefills + if !self.next_entry_waiting_too_long() { + if total_token_count < config.weight_limit / 2 { + // Don't add this new batch yet because it's not large enough + self.checked_request_count = checked_up_to_index; + self.buffer_contents_insufficient = true; + return (existing_entries_opt, None) + } + } + } // Create span for this batch to add context to inference calls - let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); + let next_batch_span = info_span!(parent: None, "batch", batch_size = next_batch_size); next_batch_span.follows_from(&Span::current()); - let mut batch_requests = Vec::with_capacity(max_batch_size); let mut batch_entries = - IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default()); - - // Iterate on buffer - while let Some((id, mut entry)) = self.entries.pop_front() { - // Filter entries where the response receiver was dropped (== entries where the request - // was dropped by the client) - if entry.response_tx.is_disconnected() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); - continue; - } + IntMap::with_capacity_and_hasher(next_batch_size, BuildNoHashHasher::default()); + let some_now = Some(Instant::now()); + let batch_requests = chosen_indices.iter().enumerate().map(|(i, index)| { + let (id, mut entry) = self.entries.remove(index - i).expect("bug"); // Create a new span to link the batch back to this entry - let entry_batch_span = info_span!(parent: &entry.span, "infer"); + let entry_batch_span = + info_span!(parent: &entry.span, "infer", batch_size = next_batch_size); // Add relationships next_batch_span.follows_from(&entry_batch_span); entry_batch_span.follows_from(&next_batch_span); // Update entry entry.temp_span = Some(entry_batch_span); - batch_requests.push(Request { + let request = Request { id, inputs: entry.request.inputs.clone(), truncate: entry.request.truncate, parameters: Some(entry.request.parameters.clone()), stopping_parameters: Some(entry.request.stopping_parameters.clone()), - }); + }; // Set batch_time - entry.batch_time = Some(Instant::now()); + entry.batch_time = some_now; // Insert in batch_entries IntMap batch_entries.insert(id, entry); - - if batch_requests.len() == max_batch_size { - // We have enough requests in the batch - break; - } - } + request + }).collect::>(); 
metrics::gauge!("tgi_queue_size", self.entries.len() as f64); - // Maybe all entries were dropped because their channel were closed - if batch_requests.is_empty() { - return None; - } - // Final batch size once we dropped entries let size = batch_requests.len() as u32; next_batch_span.record("batch_size", size); @@ -209,224 +402,30 @@ impl State { }; // Increment batch id self.next_batch_id += 1; + self.buffer_contents_insufficient = false; metrics::histogram!("tgi_batch_next_size", batch.size as f64); - Some((batch_entries, batch, next_batch_span)) + (existing_entries_opt, Some((batch_entries, batch, next_batch_span))) + } + + /// Returns true if the entry at the front of the queue has been waiting for longer + /// than MAX_WAITING_DURATION + fn next_entry_waiting_too_long(&self) -> bool { + matches!( + self.entries.front(), Some((_, e)) if e.queue_time.elapsed() > MAX_WAITING_DURATION + ) } } +type ExistingBatch = IntMap; type NextBatch = (IntMap, Batch, Span); #[derive(Debug)] enum QueueCommand { Append(Entry, Span), NextBatch { - min_size: Option, - max_size: usize, - response_sender: oneshot::Sender>, + entries: Option, + response_sender: oneshot::Sender<(Option, Option)>, span: Span, }, } - -#[cfg(test)] -mod tests { - use super::*; - use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; - use tracing::info_span; - - fn default_entry() -> ( - Entry, - flume::Receiver>, - ) { - let (response_tx, receiver_tx) = flume::unbounded(); - - let entry = Entry { - request: ValidGenerateRequest { - inputs: "".to_string(), - truncate: 0, - parameters: NextTokenChooserParameters { - temperature: 0.0, - top_k: 0, - top_p: 0.0, - typical_p: 0.0, - do_sample: false, - seed: 0, - repetition_penalty: 0.0, - watermark: false, - }, - stopping_parameters: StoppingCriteriaParameters { - ignore_eos_token: false, - max_new_tokens: 0, - stop_sequences: vec![], - }, - }, - response_tx, - span: info_span!("entry"), - temp_span: None, - queue_time: Instant::now(), - batch_time: None, - }; - (entry, receiver_tx) - } - - #[test] - fn test_append() { - let mut state = State::new(); - let (entry, _guard) = default_entry(); - - assert_eq!(state.next_id, 0); - assert_eq!(state.entries.len(), 0); - - state.append(entry); - - assert_eq!(state.next_id, 1); - assert_eq!(state.entries.len(), 1); - let (id, _) = state.entries.remove(0).unwrap(); - assert_eq!(id, 0); - } - - #[test] - fn test_next_batch_empty() { - let mut state = State::new(); - - assert!(state.next_batch(None, 1).is_none()); - assert!(state.next_batch(Some(1), 1).is_none()); - } - - #[test] - fn test_next_batch_min_size() { - let mut state = State::new(); - let (entry1, _guard1) = default_entry(); - let (entry2, _guard2) = default_entry(); - state.append(entry1); - state.append(entry2); - - let (entries, batch, _) = state.next_batch(None, 2).unwrap(); - assert_eq!(entries.len(), 2); - assert!(entries.contains_key(&0)); - assert!(entries.contains_key(&1)); - assert!(entries.get(&0).unwrap().batch_time.is_some()); - assert!(entries.get(&1).unwrap().batch_time.is_some()); - assert_eq!(batch.id, 0); - assert_eq!(batch.size, 2); - - assert_eq!(state.next_id, 2); - assert_eq!(state.entries.len(), 0); - assert_eq!(state.next_batch_id, 1); - - let (entry3, _guard3) = default_entry(); - state.append(entry3); - - assert!(state.next_batch(Some(2), 2).is_none()); - - assert_eq!(state.next_id, 3); - assert_eq!(state.entries.len(), 1); - let (id, _) = state.entries.remove(0).unwrap(); - assert_eq!(id, 2); - } - - #[test] - fn 
test_next_batch_max_size() { - let mut state = State::new(); - let (entry1, _guard1) = default_entry(); - let (entry2, _guard2) = default_entry(); - state.append(entry1); - state.append(entry2); - - let (entries, batch, _) = state.next_batch(None, 1).unwrap(); - assert_eq!(entries.len(), 1); - assert!(entries.contains_key(&0)); - assert_eq!(batch.id, 0); - assert_eq!(batch.size, 1); - - assert_eq!(state.next_id, 2); - assert_eq!(state.entries.len(), 1); - assert_eq!(state.next_batch_id, 1); - - let (entry3, _guard3) = default_entry(); - state.append(entry3); - - let (entries, batch, _) = state.next_batch(None, 3).unwrap(); - assert_eq!(entries.len(), 2); - assert!(entries.contains_key(&1)); - assert!(entries.contains_key(&2)); - assert_eq!(batch.id, 1); - assert_eq!(batch.size, 2); - - assert_eq!(state.next_id, 3); - assert_eq!(state.entries.len(), 0); - assert_eq!(state.next_batch_id, 2); - } - - #[tokio::test] - async fn test_queue_append() { - let queue = Queue::new(); - let (entry, _guard) = default_entry(); - queue.append(entry); - } - - #[tokio::test] - async fn test_queue_next_batch_empty() { - let queue = Queue::new(); - - assert!(queue.next_batch(None, 1).await.is_none()); - assert!(queue.next_batch(Some(1), 1).await.is_none()); - } - - #[tokio::test] - async fn test_queue_next_batch_min_size() { - let queue = Queue::new(); - let (entry1, _guard1) = default_entry(); - let (entry2, _guard2) = default_entry(); - queue.append(entry1); - queue.append(entry2); - - let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap(); - assert_eq!(entries.len(), 2); - assert!(entries.contains_key(&0)); - assert!(entries.contains_key(&1)); - assert!(entries.get(&0).unwrap().batch_time.is_some()); - assert!(entries.get(&1).unwrap().batch_time.is_some()); - assert_eq!(batch.id, 0); - assert_eq!(batch.size, 2); - - let (entry3, _guard3) = default_entry(); - queue.append(entry3); - - assert!(queue.next_batch(Some(2), 2).await.is_none()); - } - - #[tokio::test] - async fn test_queue_next_batch_max_size() { - let queue = Queue::new(); - let (entry1, _guard1) = default_entry(); - let (entry2, _guard2) = default_entry(); - queue.append(entry1); - queue.append(entry2); - - let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap(); - assert_eq!(entries.len(), 1); - assert!(entries.contains_key(&0)); - assert_eq!(batch.id, 0); - assert_eq!(batch.size, 1); - - let (entry3, _guard3) = default_entry(); - queue.append(entry3); - - let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap(); - assert_eq!(entries.len(), 2); - assert!(entries.contains_key(&1)); - assert!(entries.contains_key(&2)); - assert_eq!(batch.id, 1); - assert_eq!(batch.size, 2); - } - - #[tokio::test] - async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(); - let (entry, _) = default_entry(); - queue.append(entry); - - assert!(queue.next_batch(None, 1).await.is_none()); - } -} diff --git a/router/src/server.rs b/router/src/server.rs index fee748e6..47c45857 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -505,6 +505,8 @@ pub async fn run( max_input_length: usize, max_total_tokens: usize, max_batch_size: usize, + max_batch_weight: Option, + max_prefill_weight: Option, max_waiting_tokens: usize, client: ShardedClient, tokenizer: Option, @@ -552,6 +554,15 @@ pub async fn run( )] struct ApiDoc; + // If max batch weight is not set, infer from max batch size and max seq length + let max_batch_weight = max_batch_weight + .unwrap_or(max_batch_size * max_total_tokens); + let 
max_prefill_weight = max_prefill_weight.unwrap_or_default();
+
+    if max_total_tokens > max_batch_weight {
+        panic!("max_total_tokens cannot be greater than max_batch_weight");
+    }
+
     // Create state
     let validation = Validation::new(
         validation_workers,
@@ -565,6 +576,8 @@ pub async fn run(
         client,
         validation,
         max_batch_size,
+        max_batch_weight,
+        max_prefill_weight,
         max_waiting_tokens,
         max_concurrent_requests,
     );
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 5f1b89b9..a3bddda1 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -1,3 +1,4 @@
+use std::cmp::min;
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 /// Payload validation logic
 use crate::{GenerateParameters, GenerateRequest};
@@ -69,7 +70,7 @@ impl Validation {
         inputs: String,
         truncate: Option<usize>,
         max_new_tokens: u32,
-    ) -> Result<String, ValidationError> {
+    ) -> Result<(String, usize), ValidationError> {
         // If we have a fast tokenizer
         if let Some(sender) = &self.sender {
             // Create response channel
@@ -105,7 +106,7 @@ impl Validation {
             }
 
             metrics::histogram!("tgi_request_input_length", input_length as f64);
-            Ok(inputs)
+            Ok((inputs, input_length))
         }
         // Return inputs without validation
         else {
@@ -123,7 +124,7 @@ impl Validation {
                 ));
             }
 
-            Ok(inputs)
+            Ok((inputs, min(truncate.unwrap_or(usize::MAX), self.max_input_length)))
         }
     }
 
@@ -238,7 +239,7 @@ impl Validation {
             .unwrap_or(Ok(None))?;
 
         // Validate inputs
-        let inputs = self
+        let (inputs, input_length) = self
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;
 
@@ -262,7 +263,7 @@ impl Validation {
 
         Ok(ValidGenerateRequest {
             inputs,
-            truncate: truncate.unwrap_or(self.max_input_length) as u32,
+            truncate: input_length as u32,
             parameters,
             stopping_parameters,
         })
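
The batching loop in this patch only asks the queue for an add-on batch when a request completed (some_completed) or when waiting_tokens has reached max_waiting_tokens, and the queue in turn defers small add-ons unless the oldest waiter has been queued for longer than MAX_WAITING_DURATION. Below is a minimal, self-contained sketch of that two-level policy with the decisions pulled out into free functions; should_try_to_extend, should_accept_add_on and oldest_queued_since are illustrative names, not part of the patch.

use std::time::{Duration, Instant};

// Same constant as in queue.rs: how long the front of the queue may wait
// before a small add-on batch is accepted anyway.
const MAX_WAITING_DURATION: Duration = Duration::from_secs(1);

// The batching task only asks for an add-on batch when a slot freed up or
// enough decode steps have elapsed since the last prefill.
fn should_try_to_extend(some_completed: bool, waiting_tokens: usize, max_waiting_tokens: usize) -> bool {
    some_completed || waiting_tokens >= max_waiting_tokens
}

// The queue only grants the add-on when the combined worst-case token weight
// of the running batch plus the new requests reaches half the budget, unless
// the oldest queued request has been waiting too long.
fn should_accept_add_on(
    combined_token_weight: usize,
    weight_limit: usize,
    oldest_queued_since: Option<Instant>,
) -> bool {
    let waiting_too_long =
        matches!(oldest_queued_since, Some(t) if t.elapsed() > MAX_WAITING_DURATION);
    waiting_too_long || combined_token_weight >= weight_limit / 2
}

fn main() {
    // Nothing finished and only 3 of 20 waiting-token credits used: keep decoding.
    assert!(!should_try_to_extend(false, 3, 20));
    // A request finished, so there is room worth refilling.
    assert!(should_try_to_extend(true, 3, 20));
    // A light add-on is deferred while nobody has waited long...
    assert!(!should_accept_add_on(1_000, 10_000, Some(Instant::now())));
    // ...but one that fills at least half the budget is taken immediately.
    assert!(should_accept_add_on(6_000, 10_000, Some(Instant::now())));
}

In the patch itself the second check lives in State::next_batch and is skipped when hit_prefill_weight_limit is set, since in that case the prefill is already as large as it is allowed to be.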
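
The admission rule itself replaces the old max_batch_size-only gating with a token budget: each request is charged its padded input length (truncate) plus its worst-case max_new_tokens, and candidates are admitted only while the running total stays under weight_limit. A self-contained sketch of that first, coarse check follows; Cost, weight and admit are illustrative names.

#[derive(Clone, Copy)]
struct Cost {
    input_tokens: usize,   // corresponds to `truncate` in the patch
    max_new_tokens: usize, // worst-case decode length
}

// Worst-case token footprint of one request.
fn weight(c: Cost) -> usize {
    c.input_tokens + c.max_new_tokens
}

// Greedily admit candidates while the combined worst-case weight of the
// in-flight batch plus the chosen candidates stays under the budget.
fn admit(in_flight: &[Cost], candidates: &[Cost], weight_limit: usize) -> Vec<usize> {
    let mut total: usize = in_flight.iter().copied().map(weight).sum();
    let mut chosen = Vec::new();
    for (i, c) in candidates.iter().enumerate() {
        if total + weight(*c) > weight_limit {
            // The real queue does not give up here: it falls back to the
            // finer-grained shape analysis sketched next.
            continue;
        }
        total += weight(*c);
        chosen.push(i);
    }
    chosen
}

fn main() {
    let in_flight = [Cost { input_tokens: 100, max_new_tokens: 400 }];
    let queued = [
        Cost { input_tokens: 1000, max_new_tokens: 1000 }, // too heavy
        Cost { input_tokens: 50, max_new_tokens: 100 },    // fits
    ];
    assert_eq!(admit(&in_flight, &queued, 1024), vec![1]);
}

Skipping over the heavy request to reach the light one behind it is exactly what the read-only pass over the queue enables, bounded by CUTOFF_DURATION so small requests cannot jump arbitrarily far ahead of older large ones.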
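
When the coarse total would breach the budget, next_batch builds a BTreeSet keyed by (remaining output tokens, held tokens, id) and walks it from the longest-running request downwards, checking that at every future segment the candidate is still part of, the tokens already held plus the worst-case output of the requests still alive fit in the budget. The sketch below reproduces that projection over plain data; Shape and shape_fits are illustrative names. In the patch, remaining output is max_new_tokens - generated_tokens and held tokens is truncate + generated_tokens, which is why filter_send_generations now increments a per-entry generated_tokens counter.

use std::collections::BTreeSet;

// One request as seen by the shape check: how many tokens it may still
// generate, and how many tokens (prompt + already generated) it holds now.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Shape {
    remaining_output: usize,
    held_tokens: usize,
    id: u64,
}

// Returns true if adding `candidate` keeps the projected token footprint of
// the batch under `weight_limit` at every segment the candidate lives
// through. Mirrors the pass over the BTreeSet in queue.rs, as a free
// function over plain data.
fn shape_fits(batch: &BTreeSet<Shape>, candidate: Shape, weight_limit: usize) -> bool {
    let mut tree = batch.clone();
    tree.insert(candidate);
    let mut held_sum = 0;
    // Walk from the request with the most output still to produce downwards
    for (idx, shape) in tree.iter().rev().enumerate() {
        held_sum += shape.held_tokens;
        // Only segments the candidate is still part of can be made worse by it
        if shape.remaining_output <= candidate.remaining_output {
            // idx + 1 requests are still running in this segment, and each may
            // produce up to `shape.remaining_output` more tokens
            let projected = held_sum + (idx + 1) * shape.remaining_output;
            if projected > weight_limit {
                return false;
            }
        }
    }
    true
}

fn main() {
    let mut batch = BTreeSet::new();
    batch.insert(Shape { remaining_output: 500, held_tokens: 200, id: 0 });
    batch.insert(Shape { remaining_output: 100, held_tokens: 300, id: 1 });

    // A short request fits under a 1700-token budget...
    assert!(shape_fits(&batch, Shape { remaining_output: 50, held_tokens: 100, id: 2 }, 1700));
    // ...but a long-running one would overflow a future segment.
    assert!(!shape_fits(&batch, Shape { remaining_output: 900, held_tokens: 600, id: 2 }, 1700));
}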
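
Independently of the memory budget, max_prefill_weight bounds how much prompt material a single add-on prefill pass may contain, so that inserting new requests does not stall decoding of the running batch for too long. A tiny sketch of that gate; within_prefill_budget is an illustrative name, and, as in the patch, a limit of 0 disables the check.

// Gate on the prefill cost of a batch of *new* requests: the sum of their
// input lengths must stay under `prefill_weight_limit` (0 disables the check).
fn within_prefill_budget(input_lens: &[usize], prefill_weight_limit: usize) -> bool {
    if prefill_weight_limit == 0 {
        return true;
    }
    let mut prefill_size = 0usize;
    for len in input_lens {
        prefill_size += len;
        if prefill_size > prefill_weight_limit {
            return false;
        }
    }
    true
}

fn main() {
    assert!(within_prefill_budget(&[512, 512, 512], 0));       // disabled
    assert!(within_prefill_budget(&[512, 512], 2048));         // 1024 <= 2048
    assert!(!within_prefill_budget(&[1024, 1024, 512], 2048)); // 2560 > 2048
}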
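
Finally, when --max-batch-weight is not given, server::run derives a budget from the existing limits, assuming every one of the max_batch_size slots could grow to max_total_tokens. A sketch of that fallback and its sanity check; resolve_batch_weight is an illustrative name.

// Fallback mirroring server::run: if no explicit weight budget is given,
// assume every slot in the batch could grow to max_total_tokens.
fn resolve_batch_weight(
    max_batch_weight: Option<usize>,
    max_batch_size: usize,
    max_total_tokens: usize,
) -> usize {
    let weight = max_batch_weight.unwrap_or(max_batch_size * max_total_tokens);
    // A single request at its token limit must fit in the budget
    assert!(
        max_total_tokens <= weight,
        "max_total_tokens cannot be greater than max_batch_weight"
    );
    weight
}

fn main() {
    // 32 slots * 1512 total tokens -> 48384-token budget by default
    assert_eq!(resolve_batch_weight(None, 32, 1512), 48_384);
    // An explicit budget overrides the derived one
    assert_eq!(resolve_batch_weight(Some(20_000), 32, 1512), 20_000);
}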