abstract for padded batch case

This commit is contained in:
Nick Hill 2023-04-27 07:53:09 +01:00
parent 47fb2fb986
commit 08aee68f79
5 changed files with 396 additions and 65 deletions

77
Cargo.lock generated
View File

@ -1438,6 +1438,82 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "num"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43db66d1170d347f9a065114077f7dccb00c1b9478c89384490a3425279a4606"
dependencies = [
"num-bigint",
"num-complex",
"num-integer",
"num-iter",
"num-rational",
"num-traits",
]
[[package]]
name = "num-bigint"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-complex"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d"
dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9"
dependencies = [
"autocfg",
"num-traits",
]
[[package]]
name = "num-iter"
version = "0.1.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
dependencies = [
"autocfg",
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "num_cpus" name = "num_cpus"
version = "1.15.0" version = "1.15.0"
@ -2414,6 +2490,7 @@ dependencies = [
"metrics", "metrics",
"metrics-exporter-prometheus", "metrics-exporter-prometheus",
"nohash-hasher", "nohash-hasher",
"num",
"opentelemetry", "opentelemetry",
"opentelemetry-otlp", "opentelemetry-otlp",
"rand", "rand",

View File

@ -24,6 +24,7 @@ futures = "0.3.26"
metrics = "0.20.1" metrics = "0.20.1"
metrics-exporter-prometheus = { version = "0.11.0", features = [] } metrics-exporter-prometheus = { version = "0.11.0", features = [] }
nohash-hasher = "0.2.0" nohash-hasher = "0.2.0"
num = "0.4.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] } opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0" opentelemetry-otlp = "0.11.0"
rand = "0.8.5" rand = "0.8.5"

View File

@ -15,15 +15,15 @@ use thiserror::Error;
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError}; use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tokio::time::Instant; use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span}; use tracing::{info_span, instrument, Instrument, Span};
use crate::queue::BatchingConfig; use crate::queue::{BatchingConfig, BatchType};
/// Inference struct /// Inference struct
#[derive(Clone)] #[derive(Clone)]
pub struct Infer { pub(crate) struct Infer<B: BatchType> {
/// Validation /// Validation
validation: Validation, validation: Validation,
/// Request queue /// Request queue
queue: Queue, queue: Queue<B>,
/// Shared state /// Shared state
shared: Arc<Shared>, shared: Arc<Shared>,
/// Inference limit /// Inference limit
@ -36,7 +36,7 @@ struct Shared {
batching_task: Notify, batching_task: Notify,
} }
impl Infer { impl<B: BatchType> Infer<B> {
pub(crate) fn new( pub(crate) fn new(
client: ShardedClient, client: ShardedClient,
validation: Validation, validation: Validation,
@ -45,13 +45,14 @@ impl Infer {
max_prefill_weight: usize, max_prefill_weight: usize,
max_waiting_tokens: usize, max_waiting_tokens: usize,
max_concurrent_requests: usize, max_concurrent_requests: usize,
batch_type: B,
) -> Self { ) -> Self {
// Infer shared state // Infer shared state
let queue = Queue::new(BatchingConfig { let queue = Queue::new(BatchingConfig {
size_limit: max_batch_size, size_limit: max_batch_size,
weight_limit: max_batch_weight, weight_limit: max_batch_weight,
prefill_weight_limit: max_prefill_weight, prefill_weight_limit: max_prefill_weight,
}); }, batch_type);
let shared = Arc::new(Shared { let shared = Arc::new(Shared {
batching_task: Notify::new(), batching_task: Notify::new(),
}); });
@ -237,11 +238,11 @@ impl Infer {
/// Will be launched in a background Tokio task /// Will be launched in a background Tokio task
/// ///
/// Batches requests and sends them to the inference server /// Batches requests and sends them to the inference server
async fn batching_task( async fn batching_task<B: BatchType>(
mut client: ShardedClient, mut client: ShardedClient,
// max_batch_size: usize, // max_batch_size: usize,
max_waiting_tokens: usize, max_waiting_tokens: usize,
queue: Queue, queue: Queue<B>,
shared: Arc<Shared>, shared: Arc<Shared>,
) { ) {
// Infinite loop // Infinite loop

View File

@ -1,10 +1,13 @@
use std::cmp::max;
use crate::infer::InferError; use crate::infer::InferError;
use crate::infer::InferStreamResponse; use crate::infer::InferStreamResponse;
use crate::validation::ValidGenerateRequest; use crate::validation::ValidGenerateRequest;
use nohash_hasher::{BuildNoHashHasher, IntMap}; use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::collections::{BTreeSet, VecDeque}; use std::collections::{BTreeSet, VecDeque};
use std::marker::PhantomData;
use std::ops::Add; use std::ops::Add;
use std::time::Duration; use std::time::Duration;
use num::integer::Roots;
use text_generation_client::{Batch, Request}; use text_generation_client::{Batch, Request};
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tokio::time::Instant; use tokio::time::Instant;
@ -31,20 +34,22 @@ pub(crate) struct Entry {
/// Request Queue /// Request Queue
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct Queue { pub(crate) struct Queue<B: BatchType> {
/// Channel to communicate with the background queue task /// Channel to communicate with the background queue task
queue_sender: flume::Sender<QueueCommand>, queue_sender: flume::Sender<QueueCommand>,
/// Just for type inference
batch_type: PhantomData<B>,
} }
impl Queue { impl<B: BatchType> Queue<B> {
pub(crate) fn new(config: BatchingConfig) -> Self { pub(crate) fn new(config: BatchingConfig, batch_type: B) -> Self {
// Create channel // Create channel
let (queue_sender, queue_receiver) = flume::unbounded(); let (queue_sender, queue_receiver) = flume::unbounded();
// Launch background queue task // Launch background queue task
tokio::spawn(queue_task(queue_receiver, config)); tokio::spawn(queue_task(queue_receiver, config, batch_type));
Self { queue_sender } Self { queue_sender, batch_type: PhantomData }
} }
/// Append an entry to the queue /// Append an entry to the queue
@ -81,8 +86,10 @@ impl Queue {
} }
// Background task responsible of the queue state // Background task responsible of the queue state
async fn queue_task(receiver: flume::Receiver<QueueCommand>, config: BatchingConfig) { async fn queue_task<B: BatchType>(
let mut state = State::new(config); receiver: flume::Receiver<QueueCommand>, config: BatchingConfig, batch_type: B
) {
let mut state = State::new(config, batch_type);
while let Ok(cmd) = receiver.recv_async().await { while let Ok(cmd) = receiver.recv_async().await {
match cmd { match cmd {
@ -108,9 +115,10 @@ pub(crate) struct BatchingConfig {
/// Queue State /// Queue State
#[derive(Debug)] #[derive(Debug)]
struct State { struct State<B: BatchType> {
/// Batching configuration /// Batching configuration
config: BatchingConfig, config: BatchingConfig,
batch_type: PhantomData<B>,
/// Queue entries organized in a Vec /// Queue entries organized in a Vec
entries: VecDeque<(u64, Entry)>, entries: VecDeque<(u64, Entry)>,
@ -149,10 +157,133 @@ const MAX_WAITING_DURATION: Duration = Duration::from_secs(1);
/// ahead of larger ones in the queue /// ahead of larger ones in the queue
const CUTOFF_DURATION: Duration = Duration::from_secs(1); const CUTOFF_DURATION: Duration = Duration::from_secs(1);
impl State { pub(crate) trait BatchType: Send + Sync + Clone + 'static {
fn new(config: BatchingConfig) -> Self { type Stats: Default;
/// Update batch statistics with an additional request
fn update_stats(stats: &Self::Stats, input_length: usize, output_length: usize) -> Self::Stats;
/// Calculate batch weight given batch statistics
fn batch_weight(stats: &Self::Stats, batch_size: usize) -> usize;
/// Calculate prefill batch weight given prefill batch statistics
fn prefill_weight(prefill_stats: &Self::Stats, batch_size: usize) -> usize;
/// Indicate whether a hypothetical batch will exceed the combined weight limit
fn exceeds_weight(
tree: &BTreeSet<(usize, usize, &u64)>, max_total_weight: usize, current_output_len: usize
) -> bool;
/// Compute batch statistics given map of entries
fn compute_stats(entries: &IntMap<u64, Entry>) -> Self::Stats {
entries.iter().fold(
Self::Stats::default(),
|stats, (_, e)| Self::update_stats(
&stats,
e.request.truncate as usize,
e.request.stopping_parameters.max_new_tokens as usize,
)
)
}
}
/// Non-padded batch used in flash attention
#[derive(Clone)]
pub(crate) struct FlashBatch {}
impl BatchType for FlashBatch {
/// Keep track of total number of tokens in the batch
type Stats = usize;
fn update_stats(
total_tokens: &Self::Stats, input_length: usize, output_length: usize
) -> Self::Stats {
total_tokens + input_length + output_length
}
fn batch_weight(total_tokens: &Self::Stats, _batch_size: usize) -> usize {
*total_tokens
}
fn prefill_weight(total_tokens: &Self::Stats, _batch_size: usize) -> usize {
*total_tokens
}
fn exceeds_weight(
tree: &BTreeSet<(usize, usize, &u64)>, max_total_weight: usize, current_output_len: usize
) -> bool {
let mut in_sum = 0;
// Work backwards from longest projected entry
for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
let this_ol = *ol;
in_sum += *il;
if this_ol <= current_output_len {
// Check if we breach max space for this segment
let token_count = in_sum + (bs + 1) * this_ol;
if token_count > max_total_weight {
return true
}
}
}
false
}
}
/// Regular rectangular padded
#[derive(Clone)]
pub(crate) struct PaddedBatch {}
impl BatchType for PaddedBatch {
/// Keep track of maximum input length, maximum output length
type Stats = (usize, usize);
fn update_stats(
max_in_out_lengths: &Self::Stats, input_length: usize, output_length: usize
) -> Self::Stats {
let (max_input_length, max_output_length) = max_in_out_lengths;
(max(*max_input_length, input_length), max(*max_output_length, output_length))
}
fn batch_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
let (max_input_length, max_output_length) = max_in_out_lengths;
let max_seq_len = max_input_length + max_output_length;
// Memory requirement roughly propotionall to batch_size * seq_len^2
batch_size * max_seq_len.pow(2)
}
fn prefill_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
// Empirically, prefill latency is proportional to batch_size * seq_len^(3/2)
let (max_input_length, _) = max_in_out_lengths;
batch_size * max_input_length.pow(3).sqrt()
}
fn exceeds_weight(
tree: &BTreeSet<(usize, usize, &u64)>, max_total_weight: usize, current_output_len: usize
) -> bool {
let mut max_in = 0;
let mut last_ol = 0;
// Work backwards from longest projected entry
for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
let this_ol = *ol;
if this_ol != last_ol {
max_in = max(max_in, *il);
if this_ol <= current_output_len {
// Check if we breach max space for this segment
let seq_len = max_in + this_ol;
if seq_len.pow(2) * (bs + 1) > max_total_weight {
return true
}
}
last_ol = this_ol;
}
}
false
}
}
impl<B: BatchType> State<B> {
fn new(config: BatchingConfig, _batch_type: B) -> Self {
Self { Self {
config, config,
batch_type: PhantomData,
entries: VecDeque::with_capacity(128), entries: VecDeque::with_capacity(128),
next_id: 0, next_id: 0,
next_batch_id: 0, next_batch_id: 0,
@ -226,11 +357,8 @@ impl State {
let mut time_cutoff = None; let mut time_cutoff = None;
let mut hit_prefill_weight_limit = false; let mut hit_prefill_weight_limit = false;
let mut total_token_count = existing_entries.iter().map( let mut batch_stats = <B as BatchType>::compute_stats(existing_entries);
|(_, e)| e.request.stopping_parameters.max_new_tokens + e.request.truncate let mut prefill_stats = <B as BatchType>::compute_stats(&self.empty_map);
).sum::<u32>() as usize;
let mut prefill_size = 0;
// We first do a read-only pass over the queue to allow skipping over large entries // We first do a read-only pass over the queue to allow skipping over large entries
// that don't fit in the current batch to reach smaller entries that do // that don't fit in the current batch to reach smaller entries that do
let mut queue_index = checked_up_to_index; let mut queue_index = checked_up_to_index;
@ -252,10 +380,14 @@ impl State {
let input_len = entry.request.truncate as usize; let input_len = entry.request.truncate as usize;
let output_len = entry.request.stopping_parameters.max_new_tokens as usize; let output_len = entry.request.stopping_parameters.max_new_tokens as usize;
let next_total_token_count = total_token_count + input_len + output_len; let next_stats = <B as BatchType>::update_stats(
&batch_stats, input_len, output_len
);
// Avoid more granular analysis if possible // Avoid more granular analysis if possible
if next_total_token_count > config.weight_limit { if <B as BatchType>::batch_weight(
&batch_stats, total_count + 1
) > config.weight_limit {
// We aren't sure whether this next request will fit, so populate // We aren't sure whether this next request will fit, so populate
// a btree with the current batch of requests, the set of // a btree with the current batch of requests, the set of
// requests already evaluated, and this one, and perform more // requests already evaluated, and this one, and perform more
@ -283,21 +415,13 @@ impl State {
tree.insert((output_len, input_len, entry_id)); tree.insert((output_len, input_len, entry_id));
// Perform analysis // Perform analysis
let mut in_sum = 0; if <B as BatchType>::exceeds_weight(
// Work backwards from longest projected entry tree, config.weight_limit, output_len,
for (bs, (ol, il, _)) in tree.iter().rev().enumerate() { ) {
let this_ol = *ol; // Remove our tuple from the set
in_sum += *il; tree.remove(&(output_len, input_len, entry_id));
if this_ol <= output_len { time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
// Check if we breach max space for this segment continue 'queue_loop
let token_count = in_sum + (bs + 1) * this_ol;
if token_count > config.weight_limit {
// Remove our tuple from the set
tree.remove(&(output_len, input_len, entry_id));
time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
continue 'queue_loop
}
}
} }
} else if let Some(tree) = btree.as_mut() { } else if let Some(tree) = btree.as_mut() {
// If we initialized the btree for a prior request, keep it updated // If we initialized the btree for a prior request, keep it updated
@ -308,7 +432,13 @@ impl State {
// Also check whether adding this request will make the batch of new requests // Also check whether adding this request will make the batch of new requests
// too expensive latency-wise to perform in a single forward-pass. // too expensive latency-wise to perform in a single forward-pass.
if config.prefill_weight_limit > 0 { if config.prefill_weight_limit > 0 {
if prefill_size + input_len > config.prefill_weight_limit { let next_prefill_stats = <B as BatchType>::update_stats(
&prefill_stats, input_len, 0
);
let prefill_weight = <B as BatchType>::prefill_weight(
&next_prefill_stats, chosen_indices.len() + 1
);
if prefill_weight > config.prefill_weight_limit {
if let Some(tree) = btree.as_mut() { if let Some(tree) = btree.as_mut() {
// Remove our tuple from the set // Remove our tuple from the set
tree.remove(&(output_len, input_len, entry_id)); tree.remove(&(output_len, input_len, entry_id));
@ -317,10 +447,10 @@ impl State {
time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION)); time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
continue continue
} }
prefill_stats = next_prefill_stats;
} }
total_token_count = next_total_token_count; batch_stats = next_stats;
prefill_size += input_len;
chosen_indices.push(queue_index - 1); chosen_indices.push(queue_index - 1);
total_count += 1; total_count += 1;
@ -349,7 +479,7 @@ impl State {
// If this is to be added to an existing batch, ensure it meets urgency or size // If this is to be added to an existing batch, ensure it meets urgency or size
// requirements to avoid too frequent prefills // requirements to avoid too frequent prefills
if !self.next_entry_waiting_too_long() { if !self.next_entry_waiting_too_long() {
if total_token_count < config.weight_limit / 2 { if <B as BatchType>::batch_weight(&batch_stats, total_count) < config.weight_limit / 2 {
// Don't add this new batch yet because it's not large enough // Don't add this new batch yet because it's not large enough
self.checked_request_count = checked_up_to_index; self.checked_request_count = checked_up_to_index;
self.buffer_contents_insufficient = true; self.buffer_contents_insufficient = true;

View File

@ -17,15 +17,17 @@ use futures::stream::StreamExt;
use futures::Stream; use futures::Stream;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use std::convert::Infallible; use std::convert::Infallible;
use std::marker::PhantomData;
use std::net::SocketAddr; use std::net::SocketAddr;
use text_generation_client::ShardedClient; use text_generation_client::ShardedClient;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tokio::signal; use tokio::signal;
use tokio::time::Instant; use tokio::time::Instant;
use tower_http::cors::{AllowOrigin, CorsLayer}; use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument}; use tracing::{info_span, instrument, Instrument, warn};
use utoipa::OpenApi; use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi; use utoipa_swagger_ui::SwaggerUi;
use crate::queue::{BatchType, FlashBatch, PaddedBatch};
/// Generate tokens if `stream == false` or a stream of token if `stream == true` /// Generate tokens if `stream == false` or a stream of token if `stream == true`
#[utoipa::path( #[utoipa::path(
@ -46,9 +48,9 @@ use utoipa_swagger_ui::SwaggerUi;
) )
)] )]
#[instrument(skip(infer))] #[instrument(skip(infer))]
async fn compat_generate( async fn compat_generate<B: BatchType>(
default_return_full_text: Extension<bool>, default_return_full_text: Extension<bool>,
infer: Extension<Infer>, infer: Extension<Infer<B>>,
req: Json<CompatGenerateRequest>, req: Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let mut req = req.0; let mut req = req.0;
@ -92,7 +94,9 @@ async fn get_model_info(model_info: Extension<ModelInfo>) -> Json<Info> {
/// Health check method /// Health check method
#[instrument(skip(infer))] #[instrument(skip(infer))]
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> { async fn health<B: BatchType>(
infer: Extension<Infer<B>>) -> Result<(), (StatusCode, Json<ErrorResponse>
)> {
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
// be a bit too slow for a health check. // be a bit too slow for a health check.
// What we should do instead is check if the gRPC channels are still healthy. // What we should do instead is check if the gRPC channels are still healthy.
@ -151,8 +155,8 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
seed, seed,
) )
)] )]
async fn generate( async fn generate<B: BatchType>(
infer: Extension<Infer>, infer: Extension<Infer<B>>,
req: Json<GenerateRequest>, req: Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> { ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current(); let span = tracing::Span::current();
@ -333,8 +337,8 @@ async fn generate(
seed, seed,
) )
)] )]
async fn generate_stream( async fn generate_stream<B: BatchType + 'static> (
infer: Extension<Infer>, infer: Extension<Infer<B>>,
req: Json<GenerateRequest>, req: Json<GenerateRequest>,
) -> ( ) -> (
HeaderMap, HeaderMap,
@ -494,6 +498,59 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
prom_handle.render() prom_handle.render()
} }
struct BatchConfigValidator<B: BatchType> {
batch_type: PhantomData<B>
}
impl<B: BatchType> BatchConfigValidator<B> {
fn validate_batch_config(
&self,
max_total_tokens: usize,
max_batch_size: usize,
max_batch_weight: Option<usize>,
max_prefill_weight: Option<usize>,
) -> (usize, usize) {
let single_request_stats = <B as BatchType>::update_stats(
&B::Stats::default(), max_total_tokens, 0
);
let single_request_weight = <B as BatchType>::batch_weight(
&single_request_stats, 1
);
let weight_upper_bound = single_request_weight * max_batch_size;
let max_prefill_weight = if let Some(max_prefill_weight) = max_prefill_weight {
let single_request_prefill_weight = <B as BatchType>::prefill_weight(
&single_request_stats, 1
);
if max_prefill_weight < single_request_prefill_weight {
panic!("max_prefill_weight not large enough for max_total_tokens")
}
max_prefill_weight
} else {
0
};
let max_batch_weight = if let Some(mut max_batch_weight) = max_batch_weight {
if max_batch_weight < single_request_weight {
panic!("max_batch_weight not large enough for max_total_tokens")
}
if max_batch_weight > weight_upper_bound {
warn!(
"Reducing specified max_batch_weight ({}) to ({}) which is an \
upper bound based on max_total_tokens ({}) and max_batch_size ({})",
max_batch_weight, weight_upper_bound, max_total_tokens, max_batch_size
);
max_batch_weight = weight_upper_bound
}
max_batch_weight
} else {
weight_upper_bound
};
(max_batch_weight, max_prefill_weight)
}
}
/// Serving method /// Serving method
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn run( pub async fn run(
@ -513,6 +570,72 @@ pub async fn run(
validation_workers: usize, validation_workers: usize,
addr: SocketAddr, addr: SocketAddr,
allow_origin: Option<AllowOrigin>, allow_origin: Option<AllowOrigin>,
) {
//TODO get this from querying shard
let flash_attention = true;
if flash_attention {
do_run(
model_info,
compat_return_full_text,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_input_length,
max_total_tokens,
max_batch_size,
max_batch_weight,
max_prefill_weight,
max_waiting_tokens,
client,
tokenizer,
validation_workers,
addr,
allow_origin,
FlashBatch{},
).await
} else {
do_run(
model_info,
compat_return_full_text,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_input_length,
max_total_tokens,
max_batch_size,
max_batch_weight,
max_prefill_weight,
max_waiting_tokens,
client,
tokenizer,
validation_workers,
addr,
allow_origin,
PaddedBatch{},
).await
}
}
#[allow(clippy::too_many_arguments)]
async fn do_run<B: BatchType>(
model_info: ModelInfo,
compat_return_full_text: bool,
max_concurrent_requests: usize,
max_best_of: usize,
max_stop_sequences: usize,
max_input_length: usize,
max_total_tokens: usize,
max_batch_size: usize,
max_batch_weight: Option<usize>,
max_prefill_weight: Option<usize>,
max_waiting_tokens: usize,
client: ShardedClient,
tokenizer: Option<Tokenizer>,
validation_workers: usize,
addr: SocketAddr,
allow_origin: Option<AllowOrigin>,
_batch_type: B,
) { ) {
// OpenAPI documentation // OpenAPI documentation
#[derive(OpenApi)] #[derive(OpenApi)]
@ -554,14 +677,12 @@ pub async fn run(
)] )]
struct ApiDoc; struct ApiDoc;
// If max batch weight is not set, infer from max batch size and max seq length let batch_config_validator = BatchConfigValidator::<B>{batch_type: PhantomData};
let max_batch_weight = max_batch_weight
.unwrap_or(max_batch_size * max_total_tokens);
let max_prefill_weight = max_prefill_weight.unwrap_or_default();
if max_total_tokens > max_batch_weight { // If max batch weight is not set, infer from max batch size and max seq length
panic!("max_total_tokens cannot be greater than max_batch_weight"); let (max_batch_weight, max_prefill_weight) = batch_config_validator.validate_batch_config(
} max_total_tokens, max_batch_size, max_batch_weight, max_prefill_weight,
);
// Create state // Create state
let validation = Validation::new( let validation = Validation::new(
@ -580,6 +701,7 @@ pub async fn run(
max_prefill_weight, max_prefill_weight,
max_waiting_tokens, max_waiting_tokens,
max_concurrent_requests, max_concurrent_requests,
FlashBatch{}
); );
// Duration buckets // Duration buckets
@ -639,18 +761,18 @@ pub async fn run(
let app = Router::new() let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
// Base routes // Base routes
.route("/", post(compat_generate)) .route("/", post(compat_generate::<B>))
.route("/info", get(get_model_info)) .route("/info", get(get_model_info))
.route("/generate", post(generate)) .route("/generate", post(generate::<B>))
.route("/generate_stream", post(generate_stream)) .route("/generate_stream", post(generate_stream::<B>))
// AWS Sagemaker route // AWS Sagemaker route
.route("/invocations", post(compat_generate)) .route("/invocations", post(compat_generate::<B>))
// Base Health route // Base Health route
.route("/health", get(health)) .route("/health", get(health::<B>))
// Inference API health route // Inference API health route
.route("/", get(health)) .route("/", get(health::<B>))
// AWS Sagemaker health route // AWS Sagemaker health route
.route("/ping", get(health)) .route("/ping", get(health::<B>))
// Prometheus metrics route // Prometheus metrics route
.route("/metrics", get(metrics)) .route("/metrics", get(metrics))
.layer(Extension(model_info)) .layer(Extension(model_info))