From 08aee68f79a8206257b11c3fda50779db8dc2597 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 27 Apr 2023 07:53:09 +0100 Subject: [PATCH] abstract for padded batch case --- Cargo.lock | 77 +++++++++++++++++ router/Cargo.toml | 1 + router/src/infer.rs | 15 ++-- router/src/queue.rs | 202 +++++++++++++++++++++++++++++++++++-------- router/src/server.rs | 166 ++++++++++++++++++++++++++++++----- 5 files changed, 396 insertions(+), 65 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f1d2e3a2..09351ac8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1438,6 +1438,82 @@ dependencies = [ "winapi", ] +[[package]] +name = "num" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43db66d1170d347f9a065114077f7dccb00c1b9478c89384490a3425279a4606" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +dependencies = [ + "autocfg", + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +dependencies = [ + "autocfg", +] + [[package]] name = "num_cpus" version = "1.15.0" @@ -2414,6 +2490,7 @@ dependencies = [ "metrics", "metrics-exporter-prometheus", "nohash-hasher", + "num", "opentelemetry", "opentelemetry-otlp", "rand", diff --git a/router/Cargo.toml b/router/Cargo.toml index aa8e9df2..e8a20690 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -24,6 +24,7 @@ futures = "0.3.26" metrics = "0.20.1" metrics-exporter-prometheus = { version = "0.11.0", features = [] } nohash-hasher = "0.2.0" +num = "0.4.0" opentelemetry = { version = "0.18.0", features = ["rt-tokio"] } opentelemetry-otlp = "0.11.0" rand = "0.8.5" diff --git a/router/src/infer.rs b/router/src/infer.rs index 21c34d0b..60941fa6 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -15,15 +15,15 @@ use thiserror::Error; use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError}; use tokio::time::Instant; use tracing::{info_span, instrument, Instrument, Span}; -use crate::queue::BatchingConfig; +use crate::queue::{BatchingConfig, BatchType}; /// Inference struct #[derive(Clone)] -pub struct Infer { 
+pub(crate) struct Infer { /// Validation validation: Validation, /// Request queue - queue: Queue, + queue: Queue, /// Shared state shared: Arc, /// Inference limit @@ -36,7 +36,7 @@ struct Shared { batching_task: Notify, } -impl Infer { +impl Infer { pub(crate) fn new( client: ShardedClient, validation: Validation, @@ -45,13 +45,14 @@ impl Infer { max_prefill_weight: usize, max_waiting_tokens: usize, max_concurrent_requests: usize, + batch_type: B, ) -> Self { // Infer shared state let queue = Queue::new(BatchingConfig { size_limit: max_batch_size, weight_limit: max_batch_weight, prefill_weight_limit: max_prefill_weight, - }); + }, batch_type); let shared = Arc::new(Shared { batching_task: Notify::new(), }); @@ -237,11 +238,11 @@ impl Infer { /// Will be launched in a background Tokio task /// /// Batches requests and sends them to the inference server -async fn batching_task( +async fn batching_task( mut client: ShardedClient, // max_batch_size: usize, max_waiting_tokens: usize, - queue: Queue, + queue: Queue, shared: Arc, ) { // Infinite loop diff --git a/router/src/queue.rs b/router/src/queue.rs index 5496224d..098a337a 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -1,10 +1,13 @@ +use std::cmp::max; use crate::infer::InferError; use crate::infer::InferStreamResponse; use crate::validation::ValidGenerateRequest; use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::collections::{BTreeSet, VecDeque}; +use std::marker::PhantomData; use std::ops::Add; use std::time::Duration; +use num::integer::Roots; use text_generation_client::{Batch, Request}; use tokio::sync::oneshot; use tokio::time::Instant; @@ -31,20 +34,22 @@ pub(crate) struct Entry { /// Request Queue #[derive(Debug, Clone)] -pub(crate) struct Queue { +pub(crate) struct Queue { /// Channel to communicate with the background queue task queue_sender: flume::Sender, + /// Just for type inference + batch_type: PhantomData, } -impl Queue { - pub(crate) fn new(config: BatchingConfig) -> Self { +impl Queue { + pub(crate) fn new(config: BatchingConfig, batch_type: B) -> Self { // Create channel let (queue_sender, queue_receiver) = flume::unbounded(); // Launch background queue task - tokio::spawn(queue_task(queue_receiver, config)); + tokio::spawn(queue_task(queue_receiver, config, batch_type)); - Self { queue_sender } + Self { queue_sender, batch_type: PhantomData } } /// Append an entry to the queue @@ -81,8 +86,10 @@ impl Queue { } // Background task responsible of the queue state -async fn queue_task(receiver: flume::Receiver, config: BatchingConfig) { - let mut state = State::new(config); +async fn queue_task( + receiver: flume::Receiver, config: BatchingConfig, batch_type: B +) { + let mut state = State::new(config, batch_type); while let Ok(cmd) = receiver.recv_async().await { match cmd { @@ -108,9 +115,10 @@ pub(crate) struct BatchingConfig { /// Queue State #[derive(Debug)] -struct State { +struct State { /// Batching configuration config: BatchingConfig, + batch_type: PhantomData, /// Queue entries organized in a Vec entries: VecDeque<(u64, Entry)>, @@ -149,10 +157,133 @@ const MAX_WAITING_DURATION: Duration = Duration::from_secs(1); /// ahead of larger ones in the queue const CUTOFF_DURATION: Duration = Duration::from_secs(1); -impl State { - fn new(config: BatchingConfig) -> Self { +pub(crate) trait BatchType: Send + Sync + Clone + 'static { + type Stats: Default; + + /// Update batch statistics with an additional request + fn update_stats(stats: &Self::Stats, input_length: usize, output_length: usize) -> 
Self::Stats;
+    /// Calculate batch weight given batch statistics
+    fn batch_weight(stats: &Self::Stats, batch_size: usize) -> usize;
+    /// Calculate prefill batch weight given prefill batch statistics
+    fn prefill_weight(prefill_stats: &Self::Stats, batch_size: usize) -> usize;
+    /// Indicate whether a hypothetical batch will exceed the combined weight limit
+    fn exceeds_weight(
+        tree: &BTreeSet<(usize, usize, &u64)>, max_total_weight: usize, current_output_len: usize
+    ) -> bool;
+
+    /// Compute batch statistics given map of entries
+    fn compute_stats(entries: &IntMap<u64, Entry>) -> Self::Stats {
+        entries.iter().fold(
+            Self::Stats::default(),
+            |stats, (_, e)| Self::update_stats(
+                &stats,
+                e.request.truncate as usize,
+                e.request.stopping_parameters.max_new_tokens as usize,
+            )
+        )
+    }
+}
+
+/// Non-padded batch used in flash attention
+#[derive(Clone)]
+pub(crate) struct FlashBatch {}
+
+impl BatchType for FlashBatch {
+    /// Keep track of total number of tokens in the batch
+    type Stats = usize;
+
+    fn update_stats(
+        total_tokens: &Self::Stats, input_length: usize, output_length: usize
+    ) -> Self::Stats {
+        total_tokens + input_length + output_length
+    }
+
+    fn batch_weight(total_tokens: &Self::Stats, _batch_size: usize) -> usize {
+        *total_tokens
+    }
+
+    fn prefill_weight(total_tokens: &Self::Stats, _batch_size: usize) -> usize {
+        *total_tokens
+    }
+
+    fn exceeds_weight(
+        tree: &BTreeSet<(usize, usize, &u64)>, max_total_weight: usize, current_output_len: usize
+    ) -> bool {
+        let mut in_sum = 0;
+        // Work backwards from longest projected entry
+        for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
+            let this_ol = *ol;
+            in_sum += *il;
+            if this_ol <= current_output_len {
+                // Check if we breach max space for this segment
+                let token_count = in_sum + (bs + 1) * this_ol;
+                if token_count > max_total_weight {
+                    return true
+                }
+            }
+        }
+        false
+    }
+}
+
+/// Regular rectangular padded batch
+#[derive(Clone)]
+pub(crate) struct PaddedBatch {}
+
+impl BatchType for PaddedBatch {
+    /// Keep track of maximum input length, maximum output length
+    type Stats = (usize, usize);
+
+    fn update_stats(
+        max_in_out_lengths: &Self::Stats, input_length: usize, output_length: usize
+    ) -> Self::Stats {
+        let (max_input_length, max_output_length) = max_in_out_lengths;
+        (max(*max_input_length, input_length), max(*max_output_length, output_length))
+    }
+
+    fn batch_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
+        let (max_input_length, max_output_length) = max_in_out_lengths;
+        let max_seq_len = max_input_length + max_output_length;
+        // Memory requirement roughly proportional to batch_size * seq_len^2
+        batch_size * max_seq_len.pow(2)
+    }
+
+    fn prefill_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
+        // Empirically, prefill latency is proportional to batch_size * seq_len^(3/2)
+        let (max_input_length, _) = max_in_out_lengths;
+        batch_size * max_input_length.pow(3).sqrt()
+    }
+
+    fn exceeds_weight(
+        tree: &BTreeSet<(usize, usize, &u64)>, max_total_weight: usize, current_output_len: usize
+    ) -> bool {
+        let mut max_in = 0;
+        let mut last_ol = 0;
+        // Work backwards from longest projected entry
+        for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
+            let this_ol = *ol;
+            if this_ol != last_ol {
+                max_in = max(max_in, *il);
+                if this_ol <= current_output_len {
+                    // Check if we breach max space for this segment
+                    let seq_len = max_in + this_ol;
+                    if seq_len.pow(2) * (bs + 1) > max_total_weight {
+                        return true
+                    }
+                }
+                last_ol = this_ol;
+            }
+        }
+        false
+    }
+}
+
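// Illustrative sketch (not part of the diff): how the two BatchType impls above weigh
// the same hypothetical batch, assuming this lives in router/src/queue.rs scope. The
// request sizes and the helper name are made up. FlashBatch simply sums input + output
// tokens, while PaddedBatch charges every sequence for the longest input and output.
fn example_batch_weights() {
    // Three hypothetical requests as (input_length, max_new_tokens) pairs
    let requests = [(100usize, 20usize), (50, 200), (10, 10)];

    // FlashBatch stats: running total of tokens across the batch
    let flash_stats = requests
        .iter()
        .fold(0usize, |acc, (inp, out)| FlashBatch::update_stats(&acc, *inp, *out));
    // (100+20) + (50+200) + (10+10) = 390, independent of batch shape
    assert_eq!(FlashBatch::batch_weight(&flash_stats, requests.len()), 390);

    // PaddedBatch stats: (max input length, max output length) across the batch
    let padded_stats = requests
        .iter()
        .fold((0usize, 0usize), |acc, (inp, out)| PaddedBatch::update_stats(&acc, *inp, *out));
    // max_seq_len = 100 + 200 = 300; weight = batch_size * max_seq_len^2 = 3 * 90_000
    assert_eq!(PaddedBatch::batch_weight(&padded_stats, requests.len()), 270_000);
}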
+ +impl State { + fn new(config: BatchingConfig, _batch_type: B) -> Self { Self { config, + batch_type: PhantomData, entries: VecDeque::with_capacity(128), next_id: 0, next_batch_id: 0, @@ -226,11 +357,8 @@ impl State { let mut time_cutoff = None; let mut hit_prefill_weight_limit = false; - let mut total_token_count = existing_entries.iter().map( - |(_, e)| e.request.stopping_parameters.max_new_tokens + e.request.truncate - ).sum::() as usize; - - let mut prefill_size = 0; + let mut batch_stats = ::compute_stats(existing_entries); + let mut prefill_stats = ::compute_stats(&self.empty_map); // We first do a read-only pass over the queue to allow skipping over large entries // that don't fit in the current batch to reach smaller entries that do let mut queue_index = checked_up_to_index; @@ -252,10 +380,14 @@ impl State { let input_len = entry.request.truncate as usize; let output_len = entry.request.stopping_parameters.max_new_tokens as usize; - let next_total_token_count = total_token_count + input_len + output_len; + let next_stats = ::update_stats( + &batch_stats, input_len, output_len + ); // Avoid more granular analysis if possible - if next_total_token_count > config.weight_limit { + if ::batch_weight( + &batch_stats, total_count + 1 + ) > config.weight_limit { // We aren't sure whether this next request will fit, so populate // a btree with the current batch of requests, the set of // requests already evaluated, and this one, and perform more @@ -283,21 +415,13 @@ impl State { tree.insert((output_len, input_len, entry_id)); // Perform analysis - let mut in_sum = 0; - // Work backwards from longest projected entry - for (bs, (ol, il, _)) in tree.iter().rev().enumerate() { - let this_ol = *ol; - in_sum += *il; - if this_ol <= output_len { - // Check if we breach max space for this segment - let token_count = in_sum + (bs + 1) * this_ol; - if token_count > config.weight_limit { - // Remove our tuple from the set - tree.remove(&(output_len, input_len, entry_id)); - time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION)); - continue 'queue_loop - } - } + if ::exceeds_weight( + tree, config.weight_limit, output_len, + ) { + // Remove our tuple from the set + tree.remove(&(output_len, input_len, entry_id)); + time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION)); + continue 'queue_loop } } else if let Some(tree) = btree.as_mut() { // If we initialized the btree for a prior request, keep it updated @@ -308,7 +432,13 @@ impl State { // Also check whether adding this request will make the batch of new requests // too expensive latency-wise to perform in a single forward-pass. 
if config.prefill_weight_limit > 0 { - if prefill_size + input_len > config.prefill_weight_limit { + let next_prefill_stats = ::update_stats( + &prefill_stats, input_len, 0 + ); + let prefill_weight = ::prefill_weight( + &next_prefill_stats, chosen_indices.len() + 1 + ); + if prefill_weight > config.prefill_weight_limit { if let Some(tree) = btree.as_mut() { // Remove our tuple from the set tree.remove(&(output_len, input_len, entry_id)); @@ -317,10 +447,10 @@ impl State { time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION)); continue } + prefill_stats = next_prefill_stats; } - total_token_count = next_total_token_count; - prefill_size += input_len; + batch_stats = next_stats; chosen_indices.push(queue_index - 1); total_count += 1; @@ -349,7 +479,7 @@ impl State { // If this is to be added to an existing batch, ensure it meets urgency or size // requirements to avoid too frequent prefills if !self.next_entry_waiting_too_long() { - if total_token_count < config.weight_limit / 2 { + if ::batch_weight(&batch_stats, total_count) < config.weight_limit / 2 { // Don't add this new batch yet because it's not large enough self.checked_request_count = checked_up_to_index; self.buffer_contents_insufficient = true; diff --git a/router/src/server.rs b/router/src/server.rs index 47c45857..817dcbba 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -17,15 +17,17 @@ use futures::stream::StreamExt; use futures::Stream; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use std::convert::Infallible; +use std::marker::PhantomData; use std::net::SocketAddr; use text_generation_client::ShardedClient; use tokenizers::Tokenizer; use tokio::signal; use tokio::time::Instant; use tower_http::cors::{AllowOrigin, CorsLayer}; -use tracing::{info_span, instrument, Instrument}; +use tracing::{info_span, instrument, Instrument, warn}; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; +use crate::queue::{BatchType, FlashBatch, PaddedBatch}; /// Generate tokens if `stream == false` or a stream of token if `stream == true` #[utoipa::path( @@ -46,9 +48,9 @@ use utoipa_swagger_ui::SwaggerUi; ) )] #[instrument(skip(infer))] -async fn compat_generate( +async fn compat_generate( default_return_full_text: Extension, - infer: Extension, + infer: Extension>, req: Json, ) -> Result)> { let mut req = req.0; @@ -92,7 +94,9 @@ async fn get_model_info(model_info: Extension) -> Json { /// Health check method #[instrument(skip(infer))] -async fn health(infer: Extension) -> Result<(), (StatusCode, Json)> { +async fn health( + infer: Extension>) -> Result<(), (StatusCode, Json +)> { // TODO: while this is the best health check we can do, it is a bit on the heavy side and might // be a bit too slow for a health check. // What we should do instead is check if the gRPC channels are still healthy. 
@@ -151,8 +155,8 @@ async fn health(infer: Extension) -> Result<(), (StatusCode, Json, +async fn generate( + infer: Extension>, req: Json, ) -> Result<(HeaderMap, Json), (StatusCode, Json)> { let span = tracing::Span::current(); @@ -333,8 +337,8 @@ async fn generate( seed, ) )] -async fn generate_stream( - infer: Extension, +async fn generate_stream ( + infer: Extension>, req: Json, ) -> ( HeaderMap, @@ -494,6 +498,59 @@ async fn metrics(prom_handle: Extension) -> String { prom_handle.render() } +struct BatchConfigValidator { + batch_type: PhantomData +} + +impl BatchConfigValidator { + fn validate_batch_config( + &self, + max_total_tokens: usize, + max_batch_size: usize, + max_batch_weight: Option, + max_prefill_weight: Option, + ) -> (usize, usize) { + let single_request_stats = ::update_stats( + &B::Stats::default(), max_total_tokens, 0 + ); + let single_request_weight = ::batch_weight( + &single_request_stats, 1 + ); + let weight_upper_bound = single_request_weight * max_batch_size; + + let max_prefill_weight = if let Some(max_prefill_weight) = max_prefill_weight { + let single_request_prefill_weight = ::prefill_weight( + &single_request_stats, 1 + ); + if max_prefill_weight < single_request_prefill_weight { + panic!("max_prefill_weight not large enough for max_total_tokens") + } + max_prefill_weight + } else { + 0 + }; + + let max_batch_weight = if let Some(mut max_batch_weight) = max_batch_weight { + if max_batch_weight < single_request_weight { + panic!("max_batch_weight not large enough for max_total_tokens") + } + if max_batch_weight > weight_upper_bound { + warn!( + "Reducing specified max_batch_weight ({}) to ({}) which is an \ + upper bound based on max_total_tokens ({}) and max_batch_size ({})", + max_batch_weight, weight_upper_bound, max_total_tokens, max_batch_size + ); + max_batch_weight = weight_upper_bound + } + max_batch_weight + } else { + weight_upper_bound + }; + + (max_batch_weight, max_prefill_weight) + } +} + /// Serving method #[allow(clippy::too_many_arguments)] pub async fn run( @@ -513,6 +570,72 @@ pub async fn run( validation_workers: usize, addr: SocketAddr, allow_origin: Option, +) { + //TODO get this from querying shard + let flash_attention = true; + + if flash_attention { + do_run( + model_info, + compat_return_full_text, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_input_length, + max_total_tokens, + max_batch_size, + max_batch_weight, + max_prefill_weight, + max_waiting_tokens, + client, + tokenizer, + validation_workers, + addr, + allow_origin, + FlashBatch{}, + ).await + } else { + do_run( + model_info, + compat_return_full_text, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_input_length, + max_total_tokens, + max_batch_size, + max_batch_weight, + max_prefill_weight, + max_waiting_tokens, + client, + tokenizer, + validation_workers, + addr, + allow_origin, + PaddedBatch{}, + ).await + } +} + +#[allow(clippy::too_many_arguments)] +async fn do_run( + model_info: ModelInfo, + compat_return_full_text: bool, + max_concurrent_requests: usize, + max_best_of: usize, + max_stop_sequences: usize, + max_input_length: usize, + max_total_tokens: usize, + max_batch_size: usize, + max_batch_weight: Option, + max_prefill_weight: Option, + max_waiting_tokens: usize, + client: ShardedClient, + tokenizer: Option, + validation_workers: usize, + addr: SocketAddr, + allow_origin: Option, + _batch_type: B, ) { // OpenAPI documentation #[derive(OpenApi)] @@ -554,14 +677,12 @@ pub async fn run( )] struct ApiDoc; - // If max 
batch weight is not set, infer from max batch size and max seq length
-    let max_batch_weight = max_batch_weight
-        .unwrap_or(max_batch_size * max_total_tokens);
-    let max_prefill_weight = max_prefill_weight.unwrap_or_default();
+    let batch_config_validator = BatchConfigValidator::<B>{batch_type: PhantomData};
 
-    if max_total_tokens > max_batch_weight {
-        panic!("max_total_tokens cannot be greater than max_batch_weight");
-    }
+    // If max batch weight is not set, infer from max batch size and max seq length
+    let (max_batch_weight, max_prefill_weight) = batch_config_validator.validate_batch_config(
+        max_total_tokens, max_batch_size, max_batch_weight, max_prefill_weight,
+    );
 
     // Create state
     let validation = Validation::new(
@@ -580,6 +701,7 @@ pub async fn run(
         max_prefill_weight,
         max_waiting_tokens,
         max_concurrent_requests,
+        FlashBatch{}
     );
 
     // Duration buckets
@@ -639,18 +761,18 @@ pub async fn run(
     let app = Router::new()
         .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
         // Base routes
-        .route("/", post(compat_generate))
+        .route("/", post(compat_generate::<B>))
         .route("/info", get(get_model_info))
-        .route("/generate", post(generate))
-        .route("/generate_stream", post(generate_stream))
+        .route("/generate", post(generate::<B>))
+        .route("/generate_stream", post(generate_stream::<B>))
         // AWS Sagemaker route
-        .route("/invocations", post(compat_generate))
+        .route("/invocations", post(compat_generate::<B>))
         // Base Health route
-        .route("/health", get(health))
+        .route("/health", get(health::<B>))
         // Inference API health route
-        .route("/", get(health))
+        .route("/", get(health::<B>))
        // AWS Sagemaker health route
-        .route("/ping", get(health))
+        .route("/ping", get(health::<B>))
         // Prometheus metrics route
         .route("/metrics", get(metrics))
         .layer(Extension(model_info))
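Note: the following is an illustrative sketch (not part of the patch) of how the new BatchConfigValidator::validate_batch_config derives its limits for the padded case, using made-up numbers and a hypothetical helper name; it assumes the PaddedBatch and BatchType items from router/src/queue.rs are in scope.

fn example_validate_defaults() {
    let max_total_tokens = 1024usize;
    let max_batch_size = 32usize;

    // A single worst-case request fills the whole max_total_tokens budget
    let single_request_stats = PaddedBatch::update_stats(&(0, 0), max_total_tokens, 0);
    // weight = 1 * (1024 + 0)^2 = 1_048_576
    let single_request_weight = PaddedBatch::batch_weight(&single_request_stats, 1);
    assert_eq!(single_request_weight, 1_048_576);

    // When max_batch_weight is not specified this upper bound becomes the default;
    // a larger user-supplied value is clamped down to it with a warning, and a value
    // smaller than single_request_weight panics at startup.
    let weight_upper_bound = single_request_weight * max_batch_size;
    assert_eq!(weight_upper_bound, 33_554_432);
}

For FlashBatch the same derivation reduces to max_batch_size * max_total_tokens, which matches the hard-coded default removed from server.rs above.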