mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
abstract for padded batch case
This commit is contained in:
parent
47fb2fb986
commit
08aee68f79
77
Cargo.lock
generated
77
Cargo.lock
generated
@ -1438,6 +1438,82 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "43db66d1170d347f9a065114077f7dccb00c1b9478c89384490a3425279a4606"
|
||||
dependencies = [
|
||||
"num-bigint",
|
||||
"num-complex",
|
||||
"num-integer",
|
||||
"num-iter",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-bigint"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.45"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-iter"
|
||||
version = "0.1.43"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_cpus"
|
||||
version = "1.15.0"
|
||||
@ -2414,6 +2490,7 @@ dependencies = [
|
||||
"metrics",
|
||||
"metrics-exporter-prometheus",
|
||||
"nohash-hasher",
|
||||
"num",
|
||||
"opentelemetry",
|
||||
"opentelemetry-otlp",
|
||||
"rand",
|
||||
|
@ -24,6 +24,7 @@ futures = "0.3.26"
|
||||
metrics = "0.20.1"
|
||||
metrics-exporter-prometheus = { version = "0.11.0", features = [] }
|
||||
nohash-hasher = "0.2.0"
|
||||
num = "0.4.0"
|
||||
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
|
||||
opentelemetry-otlp = "0.11.0"
|
||||
rand = "0.8.5"
|
||||
|
@ -15,15 +15,15 @@ use thiserror::Error;
|
||||
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
use crate::queue::BatchingConfig;
|
||||
use crate::queue::{BatchingConfig, BatchType};
|
||||
|
||||
/// Inference struct
|
||||
#[derive(Clone)]
|
||||
pub struct Infer {
|
||||
pub(crate) struct Infer<B: BatchType> {
|
||||
/// Validation
|
||||
validation: Validation,
|
||||
/// Request queue
|
||||
queue: Queue,
|
||||
queue: Queue<B>,
|
||||
/// Shared state
|
||||
shared: Arc<Shared>,
|
||||
/// Inference limit
|
||||
@ -36,7 +36,7 @@ struct Shared {
|
||||
batching_task: Notify,
|
||||
}
|
||||
|
||||
impl Infer {
|
||||
impl<B: BatchType> Infer<B> {
|
||||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
validation: Validation,
|
||||
@ -45,13 +45,14 @@ impl Infer {
|
||||
max_prefill_weight: usize,
|
||||
max_waiting_tokens: usize,
|
||||
max_concurrent_requests: usize,
|
||||
batch_type: B,
|
||||
) -> Self {
|
||||
// Infer shared state
|
||||
let queue = Queue::new(BatchingConfig {
|
||||
size_limit: max_batch_size,
|
||||
weight_limit: max_batch_weight,
|
||||
prefill_weight_limit: max_prefill_weight,
|
||||
});
|
||||
}, batch_type);
|
||||
let shared = Arc::new(Shared {
|
||||
batching_task: Notify::new(),
|
||||
});
|
||||
@ -237,11 +238,11 @@ impl Infer {
|
||||
/// Will be launched in a background Tokio task
|
||||
///
|
||||
/// Batches requests and sends them to the inference server
|
||||
async fn batching_task(
|
||||
async fn batching_task<B: BatchType>(
|
||||
mut client: ShardedClient,
|
||||
// max_batch_size: usize,
|
||||
max_waiting_tokens: usize,
|
||||
queue: Queue,
|
||||
queue: Queue<B>,
|
||||
shared: Arc<Shared>,
|
||||
) {
|
||||
// Infinite loop
|
||||
|
@ -1,10 +1,13 @@
|
||||
use std::cmp::max;
|
||||
use crate::infer::InferError;
|
||||
use crate::infer::InferStreamResponse;
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::collections::{BTreeSet, VecDeque};
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::Add;
|
||||
use std::time::Duration;
|
||||
use num::integer::Roots;
|
||||
use text_generation_client::{Batch, Request};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::Instant;
|
||||
@ -31,20 +34,22 @@ pub(crate) struct Entry {
|
||||
|
||||
/// Request Queue
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct Queue {
|
||||
pub(crate) struct Queue<B: BatchType> {
|
||||
/// Channel to communicate with the background queue task
|
||||
queue_sender: flume::Sender<QueueCommand>,
|
||||
/// Just for type inference
|
||||
batch_type: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl Queue {
|
||||
pub(crate) fn new(config: BatchingConfig) -> Self {
|
||||
impl<B: BatchType> Queue<B> {
|
||||
pub(crate) fn new(config: BatchingConfig, batch_type: B) -> Self {
|
||||
// Create channel
|
||||
let (queue_sender, queue_receiver) = flume::unbounded();
|
||||
|
||||
// Launch background queue task
|
||||
tokio::spawn(queue_task(queue_receiver, config));
|
||||
tokio::spawn(queue_task(queue_receiver, config, batch_type));
|
||||
|
||||
Self { queue_sender }
|
||||
Self { queue_sender, batch_type: PhantomData }
|
||||
}
|
||||
|
||||
/// Append an entry to the queue
|
||||
@ -81,8 +86,10 @@ impl Queue {
|
||||
}
|
||||
|
||||
// Background task responsible of the queue state
|
||||
async fn queue_task(receiver: flume::Receiver<QueueCommand>, config: BatchingConfig) {
|
||||
let mut state = State::new(config);
|
||||
async fn queue_task<B: BatchType>(
|
||||
receiver: flume::Receiver<QueueCommand>, config: BatchingConfig, batch_type: B
|
||||
) {
|
||||
let mut state = State::new(config, batch_type);
|
||||
|
||||
while let Ok(cmd) = receiver.recv_async().await {
|
||||
match cmd {
|
||||
@ -108,9 +115,10 @@ pub(crate) struct BatchingConfig {
|
||||
|
||||
/// Queue State
|
||||
#[derive(Debug)]
|
||||
struct State {
|
||||
struct State<B: BatchType> {
|
||||
/// Batching configuration
|
||||
config: BatchingConfig,
|
||||
batch_type: PhantomData<B>,
|
||||
|
||||
/// Queue entries organized in a Vec
|
||||
entries: VecDeque<(u64, Entry)>,
|
||||
@ -149,10 +157,133 @@ const MAX_WAITING_DURATION: Duration = Duration::from_secs(1);
|
||||
/// ahead of larger ones in the queue
|
||||
const CUTOFF_DURATION: Duration = Duration::from_secs(1);
|
||||
|
||||
impl State {
|
||||
fn new(config: BatchingConfig) -> Self {
|
||||
pub(crate) trait BatchType: Send + Sync + Clone + 'static {
|
||||
type Stats: Default;
|
||||
|
||||
/// Update batch statistics with an additional request
|
||||
fn update_stats(stats: &Self::Stats, input_length: usize, output_length: usize) -> Self::Stats;
|
||||
/// Calculate batch weight given batch statistics
|
||||
fn batch_weight(stats: &Self::Stats, batch_size: usize) -> usize;
|
||||
/// Calculate prefill batch weight given prefill batch statistics
|
||||
fn prefill_weight(prefill_stats: &Self::Stats, batch_size: usize) -> usize;
|
||||
/// Indicate whether a hypothetical batch will exceed the combined weight limit
|
||||
fn exceeds_weight(
|
||||
tree: &BTreeSet<(usize, usize, &u64)>, max_total_weight: usize, current_output_len: usize
|
||||
) -> bool;
|
||||
|
||||
/// Compute batch statistics given map of entries
|
||||
fn compute_stats(entries: &IntMap<u64, Entry>) -> Self::Stats {
|
||||
entries.iter().fold(
|
||||
Self::Stats::default(),
|
||||
|stats, (_, e)| Self::update_stats(
|
||||
&stats,
|
||||
e.request.truncate as usize,
|
||||
e.request.stopping_parameters.max_new_tokens as usize,
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Non-padded batch used in flash attention
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct FlashBatch {}
|
||||
|
||||
impl BatchType for FlashBatch {
|
||||
/// Keep track of total number of tokens in the batch
|
||||
type Stats = usize;
|
||||
|
||||
fn update_stats(
|
||||
total_tokens: &Self::Stats, input_length: usize, output_length: usize
|
||||
) -> Self::Stats {
|
||||
total_tokens + input_length + output_length
|
||||
}
|
||||
|
||||
fn batch_weight(total_tokens: &Self::Stats, _batch_size: usize) -> usize {
|
||||
*total_tokens
|
||||
}
|
||||
|
||||
fn prefill_weight(total_tokens: &Self::Stats, _batch_size: usize) -> usize {
|
||||
*total_tokens
|
||||
}
|
||||
|
||||
fn exceeds_weight(
|
||||
tree: &BTreeSet<(usize, usize, &u64)>, max_total_weight: usize, current_output_len: usize
|
||||
) -> bool {
|
||||
let mut in_sum = 0;
|
||||
// Work backwards from longest projected entry
|
||||
for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
|
||||
let this_ol = *ol;
|
||||
in_sum += *il;
|
||||
if this_ol <= current_output_len {
|
||||
// Check if we breach max space for this segment
|
||||
let token_count = in_sum + (bs + 1) * this_ol;
|
||||
if token_count > max_total_weight {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Regular rectangular padded
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct PaddedBatch {}
|
||||
|
||||
impl BatchType for PaddedBatch {
|
||||
/// Keep track of maximum input length, maximum output length
|
||||
type Stats = (usize, usize);
|
||||
|
||||
fn update_stats(
|
||||
max_in_out_lengths: &Self::Stats, input_length: usize, output_length: usize
|
||||
) -> Self::Stats {
|
||||
let (max_input_length, max_output_length) = max_in_out_lengths;
|
||||
(max(*max_input_length, input_length), max(*max_output_length, output_length))
|
||||
}
|
||||
|
||||
fn batch_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
|
||||
let (max_input_length, max_output_length) = max_in_out_lengths;
|
||||
let max_seq_len = max_input_length + max_output_length;
|
||||
// Memory requirement roughly propotionall to batch_size * seq_len^2
|
||||
batch_size * max_seq_len.pow(2)
|
||||
}
|
||||
|
||||
fn prefill_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
|
||||
// Empirically, prefill latency is proportional to batch_size * seq_len^(3/2)
|
||||
let (max_input_length, _) = max_in_out_lengths;
|
||||
batch_size * max_input_length.pow(3).sqrt()
|
||||
}
|
||||
|
||||
fn exceeds_weight(
|
||||
tree: &BTreeSet<(usize, usize, &u64)>, max_total_weight: usize, current_output_len: usize
|
||||
) -> bool {
|
||||
let mut max_in = 0;
|
||||
let mut last_ol = 0;
|
||||
// Work backwards from longest projected entry
|
||||
for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
|
||||
let this_ol = *ol;
|
||||
if this_ol != last_ol {
|
||||
max_in = max(max_in, *il);
|
||||
if this_ol <= current_output_len {
|
||||
// Check if we breach max space for this segment
|
||||
let seq_len = max_in + this_ol;
|
||||
if seq_len.pow(2) * (bs + 1) > max_total_weight {
|
||||
return true
|
||||
}
|
||||
}
|
||||
last_ol = this_ol;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl<B: BatchType> State<B> {
|
||||
fn new(config: BatchingConfig, _batch_type: B) -> Self {
|
||||
Self {
|
||||
config,
|
||||
batch_type: PhantomData,
|
||||
entries: VecDeque::with_capacity(128),
|
||||
next_id: 0,
|
||||
next_batch_id: 0,
|
||||
@ -226,11 +357,8 @@ impl State {
|
||||
let mut time_cutoff = None;
|
||||
let mut hit_prefill_weight_limit = false;
|
||||
|
||||
let mut total_token_count = existing_entries.iter().map(
|
||||
|(_, e)| e.request.stopping_parameters.max_new_tokens + e.request.truncate
|
||||
).sum::<u32>() as usize;
|
||||
|
||||
let mut prefill_size = 0;
|
||||
let mut batch_stats = <B as BatchType>::compute_stats(existing_entries);
|
||||
let mut prefill_stats = <B as BatchType>::compute_stats(&self.empty_map);
|
||||
// We first do a read-only pass over the queue to allow skipping over large entries
|
||||
// that don't fit in the current batch to reach smaller entries that do
|
||||
let mut queue_index = checked_up_to_index;
|
||||
@ -252,10 +380,14 @@ impl State {
|
||||
|
||||
let input_len = entry.request.truncate as usize;
|
||||
let output_len = entry.request.stopping_parameters.max_new_tokens as usize;
|
||||
let next_total_token_count = total_token_count + input_len + output_len;
|
||||
let next_stats = <B as BatchType>::update_stats(
|
||||
&batch_stats, input_len, output_len
|
||||
);
|
||||
|
||||
// Avoid more granular analysis if possible
|
||||
if next_total_token_count > config.weight_limit {
|
||||
if <B as BatchType>::batch_weight(
|
||||
&batch_stats, total_count + 1
|
||||
) > config.weight_limit {
|
||||
// We aren't sure whether this next request will fit, so populate
|
||||
// a btree with the current batch of requests, the set of
|
||||
// requests already evaluated, and this one, and perform more
|
||||
@ -283,21 +415,13 @@ impl State {
|
||||
tree.insert((output_len, input_len, entry_id));
|
||||
|
||||
// Perform analysis
|
||||
let mut in_sum = 0;
|
||||
// Work backwards from longest projected entry
|
||||
for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
|
||||
let this_ol = *ol;
|
||||
in_sum += *il;
|
||||
if this_ol <= output_len {
|
||||
// Check if we breach max space for this segment
|
||||
let token_count = in_sum + (bs + 1) * this_ol;
|
||||
if token_count > config.weight_limit {
|
||||
// Remove our tuple from the set
|
||||
tree.remove(&(output_len, input_len, entry_id));
|
||||
time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
|
||||
continue 'queue_loop
|
||||
}
|
||||
}
|
||||
if <B as BatchType>::exceeds_weight(
|
||||
tree, config.weight_limit, output_len,
|
||||
) {
|
||||
// Remove our tuple from the set
|
||||
tree.remove(&(output_len, input_len, entry_id));
|
||||
time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
|
||||
continue 'queue_loop
|
||||
}
|
||||
} else if let Some(tree) = btree.as_mut() {
|
||||
// If we initialized the btree for a prior request, keep it updated
|
||||
@ -308,7 +432,13 @@ impl State {
|
||||
// Also check whether adding this request will make the batch of new requests
|
||||
// too expensive latency-wise to perform in a single forward-pass.
|
||||
if config.prefill_weight_limit > 0 {
|
||||
if prefill_size + input_len > config.prefill_weight_limit {
|
||||
let next_prefill_stats = <B as BatchType>::update_stats(
|
||||
&prefill_stats, input_len, 0
|
||||
);
|
||||
let prefill_weight = <B as BatchType>::prefill_weight(
|
||||
&next_prefill_stats, chosen_indices.len() + 1
|
||||
);
|
||||
if prefill_weight > config.prefill_weight_limit {
|
||||
if let Some(tree) = btree.as_mut() {
|
||||
// Remove our tuple from the set
|
||||
tree.remove(&(output_len, input_len, entry_id));
|
||||
@ -317,10 +447,10 @@ impl State {
|
||||
time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
|
||||
continue
|
||||
}
|
||||
prefill_stats = next_prefill_stats;
|
||||
}
|
||||
|
||||
total_token_count = next_total_token_count;
|
||||
prefill_size += input_len;
|
||||
batch_stats = next_stats;
|
||||
|
||||
chosen_indices.push(queue_index - 1);
|
||||
total_count += 1;
|
||||
@ -349,7 +479,7 @@ impl State {
|
||||
// If this is to be added to an existing batch, ensure it meets urgency or size
|
||||
// requirements to avoid too frequent prefills
|
||||
if !self.next_entry_waiting_too_long() {
|
||||
if total_token_count < config.weight_limit / 2 {
|
||||
if <B as BatchType>::batch_weight(&batch_stats, total_count) < config.weight_limit / 2 {
|
||||
// Don't add this new batch yet because it's not large enough
|
||||
self.checked_request_count = checked_up_to_index;
|
||||
self.buffer_contents_insufficient = true;
|
||||
|
@ -17,15 +17,17 @@ use futures::stream::StreamExt;
|
||||
use futures::Stream;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use std::convert::Infallible;
|
||||
use std::marker::PhantomData;
|
||||
use std::net::SocketAddr;
|
||||
use text_generation_client::ShardedClient;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::signal;
|
||||
use tokio::time::Instant;
|
||||
use tower_http::cors::{AllowOrigin, CorsLayer};
|
||||
use tracing::{info_span, instrument, Instrument};
|
||||
use tracing::{info_span, instrument, Instrument, warn};
|
||||
use utoipa::OpenApi;
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
use crate::queue::{BatchType, FlashBatch, PaddedBatch};
|
||||
|
||||
/// Generate tokens if `stream == false` or a stream of token if `stream == true`
|
||||
#[utoipa::path(
|
||||
@ -46,9 +48,9 @@ use utoipa_swagger_ui::SwaggerUi;
|
||||
)
|
||||
)]
|
||||
#[instrument(skip(infer))]
|
||||
async fn compat_generate(
|
||||
async fn compat_generate<B: BatchType>(
|
||||
default_return_full_text: Extension<bool>,
|
||||
infer: Extension<Infer>,
|
||||
infer: Extension<Infer<B>>,
|
||||
req: Json<CompatGenerateRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let mut req = req.0;
|
||||
@ -92,7 +94,9 @@ async fn get_model_info(model_info: Extension<ModelInfo>) -> Json<Info> {
|
||||
|
||||
/// Health check method
|
||||
#[instrument(skip(infer))]
|
||||
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||
async fn health<B: BatchType>(
|
||||
infer: Extension<Infer<B>>) -> Result<(), (StatusCode, Json<ErrorResponse>
|
||||
)> {
|
||||
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
||||
// be a bit too slow for a health check.
|
||||
// What we should do instead is check if the gRPC channels are still healthy.
|
||||
@ -151,8 +155,8 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
||||
seed,
|
||||
)
|
||||
)]
|
||||
async fn generate(
|
||||
infer: Extension<Infer>,
|
||||
async fn generate<B: BatchType>(
|
||||
infer: Extension<Infer<B>>,
|
||||
req: Json<GenerateRequest>,
|
||||
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||
let span = tracing::Span::current();
|
||||
@ -333,8 +337,8 @@ async fn generate(
|
||||
seed,
|
||||
)
|
||||
)]
|
||||
async fn generate_stream(
|
||||
infer: Extension<Infer>,
|
||||
async fn generate_stream<B: BatchType + 'static> (
|
||||
infer: Extension<Infer<B>>,
|
||||
req: Json<GenerateRequest>,
|
||||
) -> (
|
||||
HeaderMap,
|
||||
@ -494,6 +498,59 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
|
||||
prom_handle.render()
|
||||
}
|
||||
|
||||
struct BatchConfigValidator<B: BatchType> {
|
||||
batch_type: PhantomData<B>
|
||||
}
|
||||
|
||||
impl<B: BatchType> BatchConfigValidator<B> {
|
||||
fn validate_batch_config(
|
||||
&self,
|
||||
max_total_tokens: usize,
|
||||
max_batch_size: usize,
|
||||
max_batch_weight: Option<usize>,
|
||||
max_prefill_weight: Option<usize>,
|
||||
) -> (usize, usize) {
|
||||
let single_request_stats = <B as BatchType>::update_stats(
|
||||
&B::Stats::default(), max_total_tokens, 0
|
||||
);
|
||||
let single_request_weight = <B as BatchType>::batch_weight(
|
||||
&single_request_stats, 1
|
||||
);
|
||||
let weight_upper_bound = single_request_weight * max_batch_size;
|
||||
|
||||
let max_prefill_weight = if let Some(max_prefill_weight) = max_prefill_weight {
|
||||
let single_request_prefill_weight = <B as BatchType>::prefill_weight(
|
||||
&single_request_stats, 1
|
||||
);
|
||||
if max_prefill_weight < single_request_prefill_weight {
|
||||
panic!("max_prefill_weight not large enough for max_total_tokens")
|
||||
}
|
||||
max_prefill_weight
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let max_batch_weight = if let Some(mut max_batch_weight) = max_batch_weight {
|
||||
if max_batch_weight < single_request_weight {
|
||||
panic!("max_batch_weight not large enough for max_total_tokens")
|
||||
}
|
||||
if max_batch_weight > weight_upper_bound {
|
||||
warn!(
|
||||
"Reducing specified max_batch_weight ({}) to ({}) which is an \
|
||||
upper bound based on max_total_tokens ({}) and max_batch_size ({})",
|
||||
max_batch_weight, weight_upper_bound, max_total_tokens, max_batch_size
|
||||
);
|
||||
max_batch_weight = weight_upper_bound
|
||||
}
|
||||
max_batch_weight
|
||||
} else {
|
||||
weight_upper_bound
|
||||
};
|
||||
|
||||
(max_batch_weight, max_prefill_weight)
|
||||
}
|
||||
}
|
||||
|
||||
/// Serving method
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn run(
|
||||
@ -513,6 +570,72 @@ pub async fn run(
|
||||
validation_workers: usize,
|
||||
addr: SocketAddr,
|
||||
allow_origin: Option<AllowOrigin>,
|
||||
) {
|
||||
//TODO get this from querying shard
|
||||
let flash_attention = true;
|
||||
|
||||
if flash_attention {
|
||||
do_run(
|
||||
model_info,
|
||||
compat_return_full_text,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
max_batch_weight,
|
||||
max_prefill_weight,
|
||||
max_waiting_tokens,
|
||||
client,
|
||||
tokenizer,
|
||||
validation_workers,
|
||||
addr,
|
||||
allow_origin,
|
||||
FlashBatch{},
|
||||
).await
|
||||
} else {
|
||||
do_run(
|
||||
model_info,
|
||||
compat_return_full_text,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
max_batch_weight,
|
||||
max_prefill_weight,
|
||||
max_waiting_tokens,
|
||||
client,
|
||||
tokenizer,
|
||||
validation_workers,
|
||||
addr,
|
||||
allow_origin,
|
||||
PaddedBatch{},
|
||||
).await
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn do_run<B: BatchType>(
|
||||
model_info: ModelInfo,
|
||||
compat_return_full_text: bool,
|
||||
max_concurrent_requests: usize,
|
||||
max_best_of: usize,
|
||||
max_stop_sequences: usize,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
max_batch_size: usize,
|
||||
max_batch_weight: Option<usize>,
|
||||
max_prefill_weight: Option<usize>,
|
||||
max_waiting_tokens: usize,
|
||||
client: ShardedClient,
|
||||
tokenizer: Option<Tokenizer>,
|
||||
validation_workers: usize,
|
||||
addr: SocketAddr,
|
||||
allow_origin: Option<AllowOrigin>,
|
||||
_batch_type: B,
|
||||
) {
|
||||
// OpenAPI documentation
|
||||
#[derive(OpenApi)]
|
||||
@ -554,14 +677,12 @@ pub async fn run(
|
||||
)]
|
||||
struct ApiDoc;
|
||||
|
||||
// If max batch weight is not set, infer from max batch size and max seq length
|
||||
let max_batch_weight = max_batch_weight
|
||||
.unwrap_or(max_batch_size * max_total_tokens);
|
||||
let max_prefill_weight = max_prefill_weight.unwrap_or_default();
|
||||
let batch_config_validator = BatchConfigValidator::<B>{batch_type: PhantomData};
|
||||
|
||||
if max_total_tokens > max_batch_weight {
|
||||
panic!("max_total_tokens cannot be greater than max_batch_weight");
|
||||
}
|
||||
// If max batch weight is not set, infer from max batch size and max seq length
|
||||
let (max_batch_weight, max_prefill_weight) = batch_config_validator.validate_batch_config(
|
||||
max_total_tokens, max_batch_size, max_batch_weight, max_prefill_weight,
|
||||
);
|
||||
|
||||
// Create state
|
||||
let validation = Validation::new(
|
||||
@ -580,6 +701,7 @@ pub async fn run(
|
||||
max_prefill_weight,
|
||||
max_waiting_tokens,
|
||||
max_concurrent_requests,
|
||||
FlashBatch{}
|
||||
);
|
||||
|
||||
// Duration buckets
|
||||
@ -639,18 +761,18 @@ pub async fn run(
|
||||
let app = Router::new()
|
||||
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
||||
// Base routes
|
||||
.route("/", post(compat_generate))
|
||||
.route("/", post(compat_generate::<B>))
|
||||
.route("/info", get(get_model_info))
|
||||
.route("/generate", post(generate))
|
||||
.route("/generate_stream", post(generate_stream))
|
||||
.route("/generate", post(generate::<B>))
|
||||
.route("/generate_stream", post(generate_stream::<B>))
|
||||
// AWS Sagemaker route
|
||||
.route("/invocations", post(compat_generate))
|
||||
.route("/invocations", post(compat_generate::<B>))
|
||||
// Base Health route
|
||||
.route("/health", get(health))
|
||||
.route("/health", get(health::<B>))
|
||||
// Inference API health route
|
||||
.route("/", get(health))
|
||||
.route("/", get(health::<B>))
|
||||
// AWS Sagemaker health route
|
||||
.route("/ping", get(health))
|
||||
.route("/ping", get(health::<B>))
|
||||
// Prometheus metrics route
|
||||
.route("/metrics", get(metrics))
|
||||
.layer(Extension(model_info))
|
||||
|
Loading…
Reference in New Issue
Block a user