abstract for padded batch case

Nick Hill 2023-04-27 07:53:09 +01:00
parent 47fb2fb986
commit 08aee68f79
5 changed files with 396 additions and 65 deletions

Cargo.lock generated
View File

@@ -1438,6 +1438,82 @@ dependencies = [
"winapi",
]
[[package]]
name = "num"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43db66d1170d347f9a065114077f7dccb00c1b9478c89384490a3425279a4606"
dependencies = [
"num-bigint",
"num-complex",
"num-integer",
"num-iter",
"num-rational",
"num-traits",
]
[[package]]
name = "num-bigint"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-complex"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d"
dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9"
dependencies = [
"autocfg",
"num-traits",
]
[[package]]
name = "num-iter"
version = "0.1.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
dependencies = [
"autocfg",
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
dependencies = [
"autocfg",
]
[[package]]
name = "num_cpus"
version = "1.15.0"
@@ -2414,6 +2490,7 @@ dependencies = [
"metrics",
"metrics-exporter-prometheus",
"nohash-hasher",
"num",
"opentelemetry",
"opentelemetry-otlp",
"rand",

View File

@@ -24,6 +24,7 @@ futures = "0.3.26"
metrics = "0.20.1"
metrics-exporter-prometheus = { version = "0.11.0", features = [] }
nohash-hasher = "0.2.0"
num = "0.4.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0"
rand = "0.8.5"

View File

@@ -15,15 +15,15 @@ use thiserror::Error;
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span};
use crate::queue::BatchingConfig;
use crate::queue::{BatchingConfig, BatchType};
/// Inference struct
#[derive(Clone)]
pub struct Infer {
pub(crate) struct Infer<B: BatchType> {
/// Validation
validation: Validation,
/// Request queue
queue: Queue,
queue: Queue<B>,
/// Shared state
shared: Arc<Shared>,
/// Inference limit
@@ -36,7 +36,7 @@ struct Shared {
batching_task: Notify,
}
impl Infer {
impl<B: BatchType> Infer<B> {
pub(crate) fn new(
client: ShardedClient,
validation: Validation,
@@ -45,13 +45,14 @@ impl Infer {
max_prefill_weight: usize,
max_waiting_tokens: usize,
max_concurrent_requests: usize,
batch_type: B,
) -> Self {
// Infer shared state
let queue = Queue::new(BatchingConfig {
size_limit: max_batch_size,
weight_limit: max_batch_weight,
prefill_weight_limit: max_prefill_weight,
});
}, batch_type);
let shared = Arc::new(Shared {
batching_task: Notify::new(),
});
@@ -237,11 +238,11 @@ impl Infer {
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
async fn batching_task(
async fn batching_task<B: BatchType>(
mut client: ShardedClient,
// max_batch_size: usize,
max_waiting_tokens: usize,
queue: Queue,
queue: Queue<B>,
shared: Arc<Shared>,
) {
// Infinite loop

View File

@@ -1,10 +1,13 @@
use std::cmp::max;
use crate::infer::InferError;
use crate::infer::InferStreamResponse;
use crate::validation::ValidGenerateRequest;
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::collections::{BTreeSet, VecDeque};
use std::marker::PhantomData;
use std::ops::Add;
use std::time::Duration;
use num::integer::Roots;
use text_generation_client::{Batch, Request};
use tokio::sync::oneshot;
use tokio::time::Instant;
@@ -31,20 +34,22 @@ pub(crate) struct Entry {
/// Request Queue
#[derive(Debug, Clone)]
pub(crate) struct Queue {
pub(crate) struct Queue<B: BatchType> {
/// Channel to communicate with the background queue task
queue_sender: flume::Sender<QueueCommand>,
/// Just for type inference
batch_type: PhantomData<B>,
}
impl Queue {
pub(crate) fn new(config: BatchingConfig) -> Self {
impl<B: BatchType> Queue<B> {
pub(crate) fn new(config: BatchingConfig, batch_type: B) -> Self {
// Create channel
let (queue_sender, queue_receiver) = flume::unbounded();
// Launch background queue task
tokio::spawn(queue_task(queue_receiver, config));
tokio::spawn(queue_task(queue_receiver, config, batch_type));
Self { queue_sender }
Self { queue_sender, batch_type: PhantomData }
}
/// Append an entry to the queue
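
A note on the PhantomData fields introduced here: Queue<B> (and later State<B> and BatchConfigValidator<B>) never store a value of type B; the zero-sized marker only ties the type parameter to the struct so the compiler accepts it. A minimal standalone sketch of the idiom, with invented names (Handle, FastMode) rather than the router's types:

use std::marker::PhantomData;

// A handle whose behaviour is selected by the zero-sized type parameter B,
// even though no value of type B is ever stored at runtime.
struct Handle<B> {
    id: u64,
    _marker: PhantomData<B>, // satisfies the compiler's "B must be used" rule
}

struct FastMode;

fn main() {
    let h: Handle<FastMode> = Handle { id: 7, _marker: PhantomData };
    println!("handle {}", h.id);
    // PhantomData is zero-sized, so the marker adds no runtime cost.
    assert_eq!(std::mem::size_of::<Handle<FastMode>>(), std::mem::size_of::<u64>());
}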
@@ -81,8 +86,10 @@ impl Queue {
}
// Background task responsible for the queue state
async fn queue_task(receiver: flume::Receiver<QueueCommand>, config: BatchingConfig) {
let mut state = State::new(config);
async fn queue_task<B: BatchType>(
receiver: flume::Receiver<QueueCommand>, config: BatchingConfig, batch_type: B
) {
let mut state = State::new(config, batch_type);
while let Ok(cmd) = receiver.recv_async().await {
match cmd {
@@ -108,9 +115,10 @@ pub(crate) struct BatchingConfig {
/// Queue State
#[derive(Debug)]
struct State {
struct State<B: BatchType> {
/// Batching configuration
config: BatchingConfig,
batch_type: PhantomData<B>,
/// Queue entries organized in a Vec
entries: VecDeque<(u64, Entry)>,
@@ -149,10 +157,133 @@ const MAX_WAITING_DURATION: Duration = Duration::from_secs(1);
/// ahead of larger ones in the queue
const CUTOFF_DURATION: Duration = Duration::from_secs(1);
impl State {
fn new(config: BatchingConfig) -> Self {
pub(crate) trait BatchType: Send + Sync + Clone + 'static {
type Stats: Default;
/// Update batch statistics with an additional request
fn update_stats(stats: &Self::Stats, input_length: usize, output_length: usize) -> Self::Stats;
/// Calculate batch weight given batch statistics
fn batch_weight(stats: &Self::Stats, batch_size: usize) -> usize;
/// Calculate prefill batch weight given prefill batch statistics
fn prefill_weight(prefill_stats: &Self::Stats, batch_size: usize) -> usize;
/// Indicate whether a hypothetical batch will exceed the combined weight limit
fn exceeds_weight(
tree: &BTreeSet<(usize, usize, &u64)>, max_total_weight: usize, current_output_len: usize
) -> bool;
/// Compute batch statistics given map of entries
fn compute_stats(entries: &IntMap<u64, Entry>) -> Self::Stats {
entries.iter().fold(
Self::Stats::default(),
|stats, (_, e)| Self::update_stats(
&stats,
e.request.truncate as usize,
e.request.stopping_parameters.max_new_tokens as usize,
)
)
}
}
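
To make the shape of this trait concrete, the following is a stripped-down, self-contained analogue (the SimpleBatch trait, TokenSum type, weight_of helper, and the token counts are invented for illustration; only the associated-Stats-plus-fold pattern mirrors the trait above):

// Pluggable batch-weight accounting: an associated Stats type, a fold-style
// update per request, and a weight derived from the accumulated stats.
trait SimpleBatch {
    type Stats: Default;
    fn update_stats(stats: &Self::Stats, input_len: usize, output_len: usize) -> Self::Stats;
    fn batch_weight(stats: &Self::Stats, batch_size: usize) -> usize;
}

struct TokenSum;
impl SimpleBatch for TokenSum {
    type Stats = usize; // running total of input + output tokens
    fn update_stats(total: &usize, input_len: usize, output_len: usize) -> usize {
        total + input_len + output_len
    }
    fn batch_weight(total: &usize, _batch_size: usize) -> usize {
        *total
    }
}

// Generic caller, analogous to how compute_stats folds over queue entries.
fn weight_of<B: SimpleBatch>(requests: &[(usize, usize)]) -> usize {
    let stats = requests
        .iter()
        .fold(B::Stats::default(), |s, (inp, out)| B::update_stats(&s, *inp, *out));
    B::batch_weight(&stats, requests.len())
}

fn main() {
    // Two requests, given as (input_len, max_new_tokens).
    assert_eq!(weight_of::<TokenSum>(&[(40, 20), (100, 400)]), 560);
}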
/// Non-padded batch used in flash attention
#[derive(Clone)]
pub(crate) struct FlashBatch {}
impl BatchType for FlashBatch {
/// Keep track of total number of tokens in the batch
type Stats = usize;
fn update_stats(
total_tokens: &Self::Stats, input_length: usize, output_length: usize
) -> Self::Stats {
total_tokens + input_length + output_length
}
fn batch_weight(total_tokens: &Self::Stats, _batch_size: usize) -> usize {
*total_tokens
}
fn prefill_weight(total_tokens: &Self::Stats, _batch_size: usize) -> usize {
*total_tokens
}
fn exceeds_weight(
tree: &BTreeSet<(usize, usize, &u64)>, max_total_weight: usize, current_output_len: usize
) -> bool {
let mut in_sum = 0;
// Work backwards from longest projected entry
for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
let this_ol = *ol;
in_sum += *il;
if this_ol <= current_output_len {
// Check if we breach max space for this segment
let token_count = in_sum + (bs + 1) * this_ol;
if token_count > max_total_weight {
return true
}
}
}
false
}
}
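
The backward scan in exceeds_weight is easier to follow with numbers. Here is a standalone re-derivation on made-up requests (the would_exceed helper and the budgets below are illustrative; it also drops the this_ol <= current_output_len guard, which lets the router skip steps at which the candidate request would already have completed):

use std::collections::BTreeSet;

// For each entry, taken from longest projected output downwards, estimate the
// token count at the step where that entry finishes: the (bs + 1) entries with
// the longest outputs are still present, each holding its input tokens plus
// ol generated tokens, so the total is in_sum + (bs + 1) * ol.
fn would_exceed(tree: &BTreeSet<(usize, usize, u64)>, max_total_weight: usize) -> bool {
    let mut in_sum = 0;
    // Work backwards from the entry with the longest projected output.
    for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
        in_sum += *il;
        let token_count = in_sum + (bs + 1) * *ol;
        if token_count > max_total_weight {
            return true;
        }
    }
    false
}

fn main() {
    // (output_len, input_len, request id) -- invented values.
    let tree: BTreeSet<(usize, usize, u64)> =
        [(500, 100, 1), (300, 200, 2), (100, 50, 3)].into();
    // Projected peaks: 100 + 1*500 = 600, 300 + 2*300 = 900, 350 + 3*100 = 650.
    assert!(!would_exceed(&tree, 1000)); // fits under a 1000-token budget
    assert!(would_exceed(&tree, 800));   // 900 > 800, so this batch would be rejected
}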
/// Regular rectangular padded batch
#[derive(Clone)]
pub(crate) struct PaddedBatch {}
impl BatchType for PaddedBatch {
/// Keep track of maximum input length, maximum output length
type Stats = (usize, usize);
fn update_stats(
max_in_out_lengths: &Self::Stats, input_length: usize, output_length: usize
) -> Self::Stats {
let (max_input_length, max_output_length) = max_in_out_lengths;
(max(*max_input_length, input_length), max(*max_output_length, output_length))
}
fn batch_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
let (max_input_length, max_output_length) = max_in_out_lengths;
let max_seq_len = max_input_length + max_output_length;
// Memory requirement roughly proportional to batch_size * seq_len^2
batch_size * max_seq_len.pow(2)
}
fn prefill_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
// Empirically, prefill latency is proportional to batch_size * seq_len^(3/2)
let (max_input_length, _) = max_in_out_lengths;
batch_size * max_input_length.pow(3).sqrt()
}
fn exceeds_weight(
tree: &BTreeSet<(usize, usize, &u64)>, max_total_weight: usize, current_output_len: usize
) -> bool {
let mut max_in = 0;
let mut last_ol = 0;
// Work backwards from longest projected entry
for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
let this_ol = *ol;
if this_ol != last_ol {
max_in = max(max_in, *il);
if this_ol <= current_output_len {
// Check if we breach max space for this segment
let seq_len = max_in + this_ol;
if seq_len.pow(2) * (bs + 1) > max_total_weight {
return true
}
}
last_ol = this_ol;
}
}
false
}
}
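
A quick numeric check of the two cost models on the same pair of requests (a standalone sketch; the request sizes are made up, and it uses a floating-point square root instead of the num crate's integer Roots used above):

// Compare the flash (unpadded) and padded weight models on two hypothetical
// requests, given as (input_len, max_new_tokens) = (40, 20) and (100, 400).
fn main() {
    let reqs = [(40usize, 20usize), (100, 400)];

    // Flash: weight is simply the total number of real tokens in the batch.
    let flash: usize = reqs.iter().map(|(i, o)| i + o).sum();
    assert_eq!(flash, 560);

    // Padded: every sequence is padded out to max_input + max_output, and the
    // memory/attention cost grows quadratically in that padded length.
    let max_in = reqs.iter().map(|(i, _)| *i).max().unwrap(); // 100
    let max_out = reqs.iter().map(|(_, o)| *o).max().unwrap(); // 400
    let padded = reqs.len() * (max_in + max_out).pow(2); // 2 * 500^2
    assert_eq!(padded, 500_000);

    // Padded prefill weight: batch_size * max_input_len^(3/2).
    let prefill = reqs.len() * ((max_in.pow(3) as f64).sqrt() as usize); // 2 * 1000
    assert_eq!(prefill, 2000);
}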
impl<B: BatchType> State<B> {
fn new(config: BatchingConfig, _batch_type: B) -> Self {
Self {
config,
batch_type: PhantomData,
entries: VecDeque::with_capacity(128),
next_id: 0,
next_batch_id: 0,
@@ -226,11 +357,8 @@ impl State {
let mut time_cutoff = None;
let mut hit_prefill_weight_limit = false;
let mut total_token_count = existing_entries.iter().map(
|(_, e)| e.request.stopping_parameters.max_new_tokens + e.request.truncate
).sum::<u32>() as usize;
let mut prefill_size = 0;
let mut batch_stats = <B as BatchType>::compute_stats(existing_entries);
let mut prefill_stats = <B as BatchType>::compute_stats(&self.empty_map);
// We first do a read-only pass over the queue to allow skipping over large entries
// that don't fit in the current batch to reach smaller entries that do
let mut queue_index = checked_up_to_index;
@@ -252,10 +380,14 @@ impl State {
let input_len = entry.request.truncate as usize;
let output_len = entry.request.stopping_parameters.max_new_tokens as usize;
let next_total_token_count = total_token_count + input_len + output_len;
let next_stats = <B as BatchType>::update_stats(
&batch_stats, input_len, output_len
);
// Avoid more granular analysis if possible
if next_total_token_count > config.weight_limit {
if <B as BatchType>::batch_weight(
&batch_stats, total_count + 1
) > config.weight_limit {
// We aren't sure whether this next request will fit, so populate
// a btree with the current batch of requests, the set of
// requests already evaluated, and this one, and perform more
@@ -283,22 +415,14 @@ impl State {
tree.insert((output_len, input_len, entry_id));
// Perform analysis
let mut in_sum = 0;
// Work backwards from longest projected entry
for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
let this_ol = *ol;
in_sum += *il;
if this_ol <= output_len {
// Check if we breach max space for this segment
let token_count = in_sum + (bs + 1) * this_ol;
if token_count > config.weight_limit {
if <B as BatchType>::exceeds_weight(
tree, config.weight_limit, output_len,
) {
// Remove our tuple from the set
tree.remove(&(output_len, input_len, entry_id));
time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
continue 'queue_loop
}
}
}
} else if let Some(tree) = btree.as_mut() {
// If we initialized the btree for a prior request, keep it updated
tree.insert((output_len, input_len, entry_id));
@@ -308,7 +432,13 @@ impl State {
// Also check whether adding this request will make the batch of new requests
// too expensive latency-wise to perform in a single forward-pass.
if config.prefill_weight_limit > 0 {
if prefill_size + input_len > config.prefill_weight_limit {
let next_prefill_stats = <B as BatchType>::update_stats(
&prefill_stats, input_len, 0
);
let prefill_weight = <B as BatchType>::prefill_weight(
&next_prefill_stats, chosen_indices.len() + 1
);
if prefill_weight > config.prefill_weight_limit {
if let Some(tree) = btree.as_mut() {
// Remove our tuple from the set
tree.remove(&(output_len, input_len, entry_id));
@@ -317,10 +447,10 @@ impl State {
time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
continue
}
prefill_stats = next_prefill_stats;
}
total_token_count = next_total_token_count;
prefill_size += input_len;
batch_stats = next_stats;
chosen_indices.push(queue_index - 1);
total_count += 1;
@@ -349,7 +479,7 @@ impl State {
// If this is to be added to an existing batch, ensure it meets urgency or size
// requirements to avoid too frequent prefills
if !self.next_entry_waiting_too_long() {
if total_token_count < config.weight_limit / 2 {
if <B as BatchType>::batch_weight(&batch_stats, total_count) < config.weight_limit / 2 {
// Don't add this new batch yet because it's not large enough
self.checked_request_count = checked_up_to_index;
self.buffer_contents_insufficient = true;
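
The last hunk above ends with the rule that decides whether freshly selected requests should actually be prefilled into an existing batch. Restated as a small standalone decision function with invented numbers (only the halving of the weight limit comes from the diff; the function name and thresholds are illustrative):

// "Don't prefill too eagerly": when adding requests to an existing batch, only
// go ahead if the oldest queued request has waited too long, or the candidate
// additions already amount to at least half of the batch weight limit.
fn should_add_to_existing_batch(
    waiting_too_long: bool, // next_entry_waiting_too_long() in the router
    new_requests_weight: usize,
    weight_limit: usize,
) -> bool {
    waiting_too_long || new_requests_weight >= weight_limit / 2
}

fn main() {
    let weight_limit = 10_000;
    // Small top-up and nobody waiting long: hold the new requests back.
    assert!(!should_add_to_existing_batch(false, 3_000, weight_limit));
    // Same top-up, but a request has waited past the threshold: send it.
    assert!(should_add_to_existing_batch(true, 3_000, weight_limit));
    // Large enough top-up: send it regardless of waiting time.
    assert!(should_add_to_existing_batch(false, 6_000, weight_limit));
}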

View File

@@ -17,15 +17,17 @@ use futures::stream::StreamExt;
use futures::Stream;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use std::convert::Infallible;
use std::marker::PhantomData;
use std::net::SocketAddr;
use text_generation_client::ShardedClient;
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::time::Instant;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument};
use tracing::{info_span, instrument, Instrument, warn};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
use crate::queue::{BatchType, FlashBatch, PaddedBatch};
/// Generate tokens if `stream == false` or a stream of token if `stream == true`
#[utoipa::path(
@@ -46,9 +48,9 @@ use utoipa_swagger_ui::SwaggerUi;
)
)]
#[instrument(skip(infer))]
async fn compat_generate(
async fn compat_generate<B: BatchType>(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
infer: Extension<Infer<B>>,
req: Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let mut req = req.0;
@@ -92,7 +94,9 @@ async fn get_model_info(model_info: Extension<ModelInfo>) -> Json<Info> {
/// Health check method
#[instrument(skip(infer))]
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
async fn health<B: BatchType>(
infer: Extension<Infer<B>>,
) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
// be a bit too slow for a health check.
// What we should do instead is check if the gRPC channels are still healthy.
@@ -151,8 +155,8 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
seed,
)
)]
async fn generate(
infer: Extension<Infer>,
async fn generate<B: BatchType>(
infer: Extension<Infer<B>>,
req: Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
@@ -333,8 +337,8 @@ async fn generate(
seed,
)
)]
async fn generate_stream(
infer: Extension<Infer>,
async fn generate_stream<B: BatchType + 'static> (
infer: Extension<Infer<B>>,
req: Json<GenerateRequest>,
) -> (
HeaderMap,
@@ -494,6 +498,59 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
prom_handle.render()
}
struct BatchConfigValidator<B: BatchType> {
batch_type: PhantomData<B>
}
impl<B: BatchType> BatchConfigValidator<B> {
fn validate_batch_config(
&self,
max_total_tokens: usize,
max_batch_size: usize,
max_batch_weight: Option<usize>,
max_prefill_weight: Option<usize>,
) -> (usize, usize) {
let single_request_stats = <B as BatchType>::update_stats(
&B::Stats::default(), max_total_tokens, 0
);
let single_request_weight = <B as BatchType>::batch_weight(
&single_request_stats, 1
);
let weight_upper_bound = single_request_weight * max_batch_size;
let max_prefill_weight = if let Some(max_prefill_weight) = max_prefill_weight {
let single_request_prefill_weight = <B as BatchType>::prefill_weight(
&single_request_stats, 1
);
if max_prefill_weight < single_request_prefill_weight {
panic!("max_prefill_weight not large enough for max_total_tokens")
}
max_prefill_weight
} else {
0
};
let max_batch_weight = if let Some(mut max_batch_weight) = max_batch_weight {
if max_batch_weight < single_request_weight {
panic!("max_batch_weight not large enough for max_total_tokens")
}
if max_batch_weight > weight_upper_bound {
warn!(
"Reducing specified max_batch_weight ({}) to ({}) which is an \
upper bound based on max_total_tokens ({}) and max_batch_size ({})",
max_batch_weight, weight_upper_bound, max_total_tokens, max_batch_size
);
max_batch_weight = weight_upper_bound
}
max_batch_weight
} else {
weight_upper_bound
};
(max_batch_weight, max_prefill_weight)
}
}
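
The arithmetic in validate_batch_config is simplest to see with numbers. A standalone sketch for the flash (token-sum) case with invented limits (the real method also derives and checks the prefill weight, omitted here):

// Flash case: a single request of max_total_tokens tokens weighs
// max_total_tokens, so max_batch_size such requests bound the batch weight.
fn main() {
    let max_total_tokens = 2_048usize;
    let max_batch_size = 12usize;

    let single_request_weight = max_total_tokens; // 2 048
    let weight_upper_bound = single_request_weight * max_batch_size; // 24 576

    // A configured max_batch_weight above the bound gets capped (the router
    // logs a warning); one below single_request_weight would panic instead.
    let configured: Option<usize> = Some(100_000);
    let effective = match configured {
        Some(w) if w < single_request_weight => panic!("max_batch_weight not large enough"),
        Some(w) => w.min(weight_upper_bound),
        None => weight_upper_bound,
    };
    assert_eq!(effective, 24_576);
}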
/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
@@ -513,6 +570,72 @@ pub async fn run(
validation_workers: usize,
addr: SocketAddr,
allow_origin: Option<AllowOrigin>,
) {
//TODO get this from querying shard
let flash_attention = true;
if flash_attention {
do_run(
model_info,
compat_return_full_text,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_input_length,
max_total_tokens,
max_batch_size,
max_batch_weight,
max_prefill_weight,
max_waiting_tokens,
client,
tokenizer,
validation_workers,
addr,
allow_origin,
FlashBatch{},
).await
} else {
do_run(
model_info,
compat_return_full_text,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_input_length,
max_total_tokens,
max_batch_size,
max_batch_weight,
max_prefill_weight,
max_waiting_tokens,
client,
tokenizer,
validation_workers,
addr,
allow_origin,
PaddedBatch{},
).await
}
}
#[allow(clippy::too_many_arguments)]
async fn do_run<B: BatchType>(
model_info: ModelInfo,
compat_return_full_text: bool,
max_concurrent_requests: usize,
max_best_of: usize,
max_stop_sequences: usize,
max_input_length: usize,
max_total_tokens: usize,
max_batch_size: usize,
max_batch_weight: Option<usize>,
max_prefill_weight: Option<usize>,
max_waiting_tokens: usize,
client: ShardedClient,
tokenizer: Option<Tokenizer>,
validation_workers: usize,
addr: SocketAddr,
allow_origin: Option<AllowOrigin>,
batch_type: B,
) {
// OpenAPI documentation
#[derive(OpenApi)]
@@ -554,14 +677,12 @@ pub async fn run(
)]
struct ApiDoc;
// If max batch weight is not set, infer from max batch size and max seq length
let max_batch_weight = max_batch_weight
.unwrap_or(max_batch_size * max_total_tokens);
let max_prefill_weight = max_prefill_weight.unwrap_or_default();
let batch_config_validator = BatchConfigValidator::<B>{batch_type: PhantomData};
if max_total_tokens > max_batch_weight {
panic!("max_total_tokens cannot be greater than max_batch_weight");
}
// If max batch weight is not set, infer from max batch size and max seq length
let (max_batch_weight, max_prefill_weight) = batch_config_validator.validate_batch_config(
max_total_tokens, max_batch_size, max_batch_weight, max_prefill_weight,
);
// Create state
let validation = Validation::new(
@@ -580,6 +701,7 @@ pub async fn run(
max_prefill_weight,
max_waiting_tokens,
max_concurrent_requests,
batch_type
);
// Duration buckets
@@ -639,18 +761,18 @@
let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
// Base routes
.route("/", post(compat_generate))
.route("/", post(compat_generate::<B>))
.route("/info", get(get_model_info))
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
.route("/generate", post(generate::<B>))
.route("/generate_stream", post(generate_stream::<B>))
// AWS Sagemaker route
.route("/invocations", post(compat_generate))
.route("/invocations", post(compat_generate::<B>))
// Base Health route
.route("/health", get(health))
.route("/health", get(health::<B>))
// Inference API health route
.route("/", get(health))
.route("/", get(health::<B>))
// AWS Sagemaker health route
.route("/ping", get(health))
.route("/ping", get(health::<B>))
// Prometheus metrics route
.route("/metrics", get(metrics))
.layer(Extension(model_info))