mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00

abstract for padded batch case

This commit is contained in:
parent 47fb2fb986
commit 08aee68f79
Cargo.lock (generated): 77 lines changed
@@ -1438,6 +1438,82 @@ dependencies = [
  "winapi",
 ]
 
+[[package]]
+name = "num"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "43db66d1170d347f9a065114077f7dccb00c1b9478c89384490a3425279a4606"
+dependencies = [
+ "num-bigint",
+ "num-complex",
+ "num-integer",
+ "num-iter",
+ "num-rational",
+ "num-traits",
+]
+
+[[package]]
+name = "num-bigint"
+version = "0.4.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f"
+dependencies = [
+ "autocfg",
+ "num-integer",
+ "num-traits",
+]
+
+[[package]]
+name = "num-complex"
+version = "0.4.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d"
+dependencies = [
+ "num-traits",
+]
+
+[[package]]
+name = "num-integer"
+version = "0.1.45"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9"
+dependencies = [
+ "autocfg",
+ "num-traits",
+]
+
+[[package]]
+name = "num-iter"
+version = "0.1.43"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252"
+dependencies = [
+ "autocfg",
+ "num-integer",
+ "num-traits",
+]
+
+[[package]]
+name = "num-rational"
+version = "0.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
+dependencies = [
+ "autocfg",
+ "num-bigint",
+ "num-integer",
+ "num-traits",
+]
+
+[[package]]
+name = "num-traits"
+version = "0.2.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
+dependencies = [
+ "autocfg",
+]
+
 [[package]]
 name = "num_cpus"
 version = "1.15.0"
@@ -2414,6 +2490,7 @@ dependencies = [
  "metrics",
  "metrics-exporter-prometheus",
  "nohash-hasher",
+ "num",
  "opentelemetry",
  "opentelemetry-otlp",
  "rand",
router/Cargo.toml (path inferred from hunk context)
@@ -24,6 +24,7 @@ futures = "0.3.26"
 metrics = "0.20.1"
 metrics-exporter-prometheus = { version = "0.11.0", features = [] }
 nohash-hasher = "0.2.0"
+num = "0.4.0"
 opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.11.0"
 rand = "0.8.5"
router/src/infer.rs (path inferred from hunk context)
@@ -15,15 +15,15 @@ use thiserror::Error;
 use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Instrument, Span};
-use crate::queue::BatchingConfig;
+use crate::queue::{BatchingConfig, BatchType};
 
 /// Inference struct
 #[derive(Clone)]
-pub struct Infer {
+pub(crate) struct Infer<B: BatchType> {
     /// Validation
     validation: Validation,
     /// Request queue
-    queue: Queue,
+    queue: Queue<B>,
     /// Shared state
     shared: Arc<Shared>,
     /// Inference limit
@@ -36,7 +36,7 @@ struct Shared {
     batching_task: Notify,
 }
 
-impl Infer {
+impl<B: BatchType> Infer<B> {
     pub(crate) fn new(
         client: ShardedClient,
         validation: Validation,
@@ -45,13 +45,14 @@ impl Infer {
         max_prefill_weight: usize,
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
+        batch_type: B,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(BatchingConfig {
             size_limit: max_batch_size,
             weight_limit: max_batch_weight,
             prefill_weight_limit: max_prefill_weight,
-        });
+        }, batch_type);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
@@ -237,11 +238,11 @@ impl Infer {
 /// Will be launched in a background Tokio task
 ///
 /// Batches requests and sends them to the inference server
-async fn batching_task(
+async fn batching_task<B: BatchType>(
     mut client: ShardedClient,
     // max_batch_size: usize,
     max_waiting_tokens: usize,
-    queue: Queue,
+    queue: Queue<B>,
     shared: Arc<Shared>,
 ) {
     // Infinite loop
router/src/queue.rs (path inferred from hunk context)
@@ -1,10 +1,13 @@
+use std::cmp::max;
 use crate::infer::InferError;
 use crate::infer::InferStreamResponse;
 use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::collections::{BTreeSet, VecDeque};
+use std::marker::PhantomData;
 use std::ops::Add;
 use std::time::Duration;
+use num::integer::Roots;
 use text_generation_client::{Batch, Request};
 use tokio::sync::oneshot;
 use tokio::time::Instant;
@@ -31,20 +34,22 @@ pub(crate) struct Entry {
 
 /// Request Queue
 #[derive(Debug, Clone)]
-pub(crate) struct Queue {
+pub(crate) struct Queue<B: BatchType> {
     /// Channel to communicate with the background queue task
     queue_sender: flume::Sender<QueueCommand>,
+    /// Just for type inference
+    batch_type: PhantomData<B>,
 }
 
-impl Queue {
-    pub(crate) fn new(config: BatchingConfig) -> Self {
+impl<B: BatchType> Queue<B> {
+    pub(crate) fn new(config: BatchingConfig, batch_type: B) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = flume::unbounded();
 
         // Launch background queue task
-        tokio::spawn(queue_task(queue_receiver, config));
+        tokio::spawn(queue_task(queue_receiver, config, batch_type));
 
-        Self { queue_sender }
+        Self { queue_sender, batch_type: PhantomData }
     }
 
     /// Append an entry to the queue
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Background task responsible of the queue state
|
// Background task responsible of the queue state
|
||||||
async fn queue_task(receiver: flume::Receiver<QueueCommand>, config: BatchingConfig) {
|
async fn queue_task<B: BatchType>(
|
||||||
let mut state = State::new(config);
|
receiver: flume::Receiver<QueueCommand>, config: BatchingConfig, batch_type: B
|
||||||
|
) {
|
||||||
|
let mut state = State::new(config, batch_type);
|
||||||
|
|
||||||
while let Ok(cmd) = receiver.recv_async().await {
|
while let Ok(cmd) = receiver.recv_async().await {
|
||||||
match cmd {
|
match cmd {
|
||||||
@@ -108,9 +115,10 @@ pub(crate) struct BatchingConfig {
 
 /// Queue State
 #[derive(Debug)]
-struct State {
+struct State<B: BatchType> {
     /// Batching configuration
     config: BatchingConfig,
+    batch_type: PhantomData<B>,
 
     /// Queue entries organized in a Vec
     entries: VecDeque<(u64, Entry)>,
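The batch_type: PhantomData<B> field on State (and on Queue above) is the standard zero-sized marker for a type parameter that is never stored: State only calls B's associated functions, so the parameter has to be recorded without occupying space. A minimal self-contained sketch of the pattern, with hypothetical names not taken from this codebase:

    use std::marker::PhantomData;

    // The struct never holds a `W` value; PhantomData<W> records the
    // parameter so `W`'s associated functions can be called statically.
    trait Weigher {
        fn weight(items: usize) -> usize;
    }

    struct Tracker<W: Weigher> {
        count: usize,
        _weigher: PhantomData<W>,
    }

    impl<W: Weigher> Tracker<W> {
        fn new() -> Self {
            Tracker { count: 0, _weigher: PhantomData }
        }
        fn add(&mut self) -> usize {
            self.count += 1;
            W::weight(self.count) // static dispatch, no vtable or stored value
        }
    }

    struct Linear;
    impl Weigher for Linear {
        fn weight(items: usize) -> usize { items }
    }

    fn main() {
        let mut tracker = Tracker::<Linear>::new();
        assert_eq!(tracker.add(), 1);
    }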
@@ -149,10 +157,133 @@ const MAX_WAITING_DURATION: Duration = Duration::from_secs(1);
 /// ahead of larger ones in the queue
 const CUTOFF_DURATION: Duration = Duration::from_secs(1);
 
-impl State {
-    fn new(config: BatchingConfig) -> Self {
+pub(crate) trait BatchType: Send + Sync + Clone + 'static {
+    type Stats: Default;
+
+    /// Update batch statistics with an additional request
+    fn update_stats(stats: &Self::Stats, input_length: usize, output_length: usize) -> Self::Stats;
+    /// Calculate batch weight given batch statistics
+    fn batch_weight(stats: &Self::Stats, batch_size: usize) -> usize;
+    /// Calculate prefill batch weight given prefill batch statistics
+    fn prefill_weight(prefill_stats: &Self::Stats, batch_size: usize) -> usize;
+    /// Indicate whether a hypothetical batch will exceed the combined weight limit
+    fn exceeds_weight(
+        tree: &BTreeSet<(usize, usize, &u64)>, max_total_weight: usize, current_output_len: usize
+    ) -> bool;
+
+    /// Compute batch statistics given map of entries
+    fn compute_stats(entries: &IntMap<u64, Entry>) -> Self::Stats {
+        entries.iter().fold(
+            Self::Stats::default(),
+            |stats, (_, e)| Self::update_stats(
+                &stats,
+                e.request.truncate as usize,
+                e.request.stopping_parameters.max_new_tokens as usize,
+            )
+        )
+    }
+}
+
+/// Non-padded batch used in flash attention
+#[derive(Clone)]
+pub(crate) struct FlashBatch {}
+
+impl BatchType for FlashBatch {
+    /// Keep track of total number of tokens in the batch
+    type Stats = usize;
+
+    fn update_stats(
+        total_tokens: &Self::Stats, input_length: usize, output_length: usize
+    ) -> Self::Stats {
+        total_tokens + input_length + output_length
+    }
+
+    fn batch_weight(total_tokens: &Self::Stats, _batch_size: usize) -> usize {
+        *total_tokens
+    }
+
+    fn prefill_weight(total_tokens: &Self::Stats, _batch_size: usize) -> usize {
+        *total_tokens
+    }
+
+    fn exceeds_weight(
+        tree: &BTreeSet<(usize, usize, &u64)>, max_total_weight: usize, current_output_len: usize
+    ) -> bool {
+        let mut in_sum = 0;
+        // Work backwards from longest projected entry
+        for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
+            let this_ol = *ol;
+            in_sum += *il;
+            if this_ol <= current_output_len {
+                // Check if we breach max space for this segment
+                let token_count = in_sum + (bs + 1) * this_ol;
+                if token_count > max_total_weight {
+                    return true
+                }
+            }
+        }
+        false
+    }
+}
+
+/// Regular rectangular padded batch
+#[derive(Clone)]
+pub(crate) struct PaddedBatch {}
+
+impl BatchType for PaddedBatch {
+    /// Keep track of maximum input length, maximum output length
+    type Stats = (usize, usize);
+
+    fn update_stats(
+        max_in_out_lengths: &Self::Stats, input_length: usize, output_length: usize
+    ) -> Self::Stats {
+        let (max_input_length, max_output_length) = max_in_out_lengths;
+        (max(*max_input_length, input_length), max(*max_output_length, output_length))
+    }
+
+    fn batch_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
+        let (max_input_length, max_output_length) = max_in_out_lengths;
+        let max_seq_len = max_input_length + max_output_length;
+        // Memory requirement roughly proportional to batch_size * seq_len^2
+        batch_size * max_seq_len.pow(2)
+    }
+
+    fn prefill_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
+        // Empirically, prefill latency is proportional to batch_size * seq_len^(3/2)
+        let (max_input_length, _) = max_in_out_lengths;
+        batch_size * max_input_length.pow(3).sqrt()
+    }
+
+    fn exceeds_weight(
+        tree: &BTreeSet<(usize, usize, &u64)>, max_total_weight: usize, current_output_len: usize
+    ) -> bool {
+        let mut max_in = 0;
+        let mut last_ol = 0;
+        // Work backwards from longest projected entry
+        for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
+            let this_ol = *ol;
+            if this_ol != last_ol {
+                max_in = max(max_in, *il);
+                if this_ol <= current_output_len {
+                    // Check if we breach max space for this segment
+                    let seq_len = max_in + this_ol;
+                    if seq_len.pow(2) * (bs + 1) > max_total_weight {
+                        return true
+                    }
+                }
+                last_ol = this_ol;
+            }
+        }
+        false
+    }
+}
+
+impl<B: BatchType> State<B> {
+    fn new(config: BatchingConfig, _batch_type: B) -> Self {
         Self {
             config,
+            batch_type: PhantomData,
             entries: VecDeque::with_capacity(128),
             next_id: 0,
             next_batch_id: 0,
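To make the three weight formulas above concrete, a small standalone sketch with hypothetical request lengths; num's integer Roots trait supplies the truncated square root exactly as prefill_weight uses it:

    use num::integer::Roots;

    fn main() {
        // Hypothetical batch: 4 requests, each 200 input + 50 output tokens.
        let (batch_size, input_len, output_len) = (4usize, 200usize, 50usize);

        // FlashBatch: weight is the plain token total, since unpadded
        // batches pay no cost for length variance.
        let flash_weight = batch_size * (input_len + output_len);
        assert_eq!(flash_weight, 1_000);

        // PaddedBatch: memory grows roughly with batch_size * seq_len^2.
        let padded_weight = batch_size * (input_len + output_len).pow(2);
        assert_eq!(padded_weight, 250_000);

        // PaddedBatch prefill: batch_size * max_input^(3/2), computed as the
        // integer square root of max_input^3 (isqrt(8_000_000) = 2828).
        let prefill_weight = batch_size * input_len.pow(3).sqrt();
        assert_eq!(prefill_weight, 11_312);
    }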
@@ -226,11 +357,8 @@ impl State {
         let mut time_cutoff = None;
         let mut hit_prefill_weight_limit = false;
 
-        let mut total_token_count = existing_entries.iter().map(
-            |(_, e)| e.request.stopping_parameters.max_new_tokens + e.request.truncate
-        ).sum::<u32>() as usize;
-
-        let mut prefill_size = 0;
+        let mut batch_stats = <B as BatchType>::compute_stats(existing_entries);
+        let mut prefill_stats = <B as BatchType>::compute_stats(&self.empty_map);
 
         // We first do a read-only pass over the queue to allow skipping over large entries
         // that don't fit in the current batch to reach smaller entries that do
         let mut queue_index = checked_up_to_index;
@@ -252,10 +380,14 @@ impl State {
 
             let input_len = entry.request.truncate as usize;
             let output_len = entry.request.stopping_parameters.max_new_tokens as usize;
-            let next_total_token_count = total_token_count + input_len + output_len;
+            let next_stats = <B as BatchType>::update_stats(
+                &batch_stats, input_len, output_len
+            );
 
             // Avoid more granular analysis if possible
-            if next_total_token_count > config.weight_limit {
+            if <B as BatchType>::batch_weight(
+                &batch_stats, total_count + 1
+            ) > config.weight_limit {
                 // We aren't sure whether this next request will fit, so populate
                 // a btree with the current batch of requests, the set of
                 // requests already evaluated, and this one, and perform more
@@ -283,22 +415,14 @@ impl State {
                 tree.insert((output_len, input_len, entry_id));
 
                 // Perform analysis
-                let mut in_sum = 0;
-                // Work backwards from longest projected entry
-                for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
-                    let this_ol = *ol;
-                    in_sum += *il;
-                    if this_ol <= output_len {
-                        // Check if we breach max space for this segment
-                        let token_count = in_sum + (bs + 1) * this_ol;
-                        if token_count > config.weight_limit {
-                            // Remove our tuple from the set
-                            tree.remove(&(output_len, input_len, entry_id));
-                            time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
-                            continue 'queue_loop
-                        }
-                    }
-                }
+                if <B as BatchType>::exceeds_weight(
+                    tree, config.weight_limit, output_len,
+                ) {
+                    // Remove our tuple from the set
+                    tree.remove(&(output_len, input_len, entry_id));
+                    time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
+                    continue 'queue_loop
+                }
             } else if let Some(tree) = btree.as_mut() {
                 // If we initialized the btree for a prior request, keep it updated
                 tree.insert((output_len, input_len, entry_id));
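For intuition on the check that exceeds_weight now encapsulates: the btree keys are (output_len, input_len, id) tuples, so reverse iteration visits requests by descending projected output length, and each segment charges its output length to every request still generating at that point. A simplified standalone sketch of the FlashBatch variant, with hypothetical lengths and the current_output_len guard omitted for brevity:

    use std::collections::BTreeSet;

    fn main() {
        let ids = [0u64, 1, 2];
        let mut tree: BTreeSet<(usize, usize, &u64)> = BTreeSet::new();
        tree.insert((50, 200, &ids[0]));  // 200 in, 50 out: finishes first
        tree.insert((100, 400, &ids[1]));
        tree.insert((500, 30, &ids[2]));  // 30 in, 500 out: generates longest

        let weight_limit = 700usize;
        let mut in_sum = 0;
        let mut exceeds = false;
        // Walk from the longest projected output backwards, accumulating
        // input lengths; bs + 1 requests are still active in each segment.
        for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
            in_sum += *il;
            if in_sum + (bs + 1) * *ol > weight_limit {
                exceeds = true;
            }
        }
        // Segments weigh 530, 630 and 780 tokens; 780 > 700.
        assert!(exceeds);
    }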
@@ -308,7 +432,13 @@ impl State {
             // Also check whether adding this request will make the batch of new requests
             // too expensive latency-wise to perform in a single forward-pass.
             if config.prefill_weight_limit > 0 {
-                if prefill_size + input_len > config.prefill_weight_limit {
+                let next_prefill_stats = <B as BatchType>::update_stats(
+                    &prefill_stats, input_len, 0
+                );
+                let prefill_weight = <B as BatchType>::prefill_weight(
+                    &next_prefill_stats, chosen_indices.len() + 1
+                );
+                if prefill_weight > config.prefill_weight_limit {
                     if let Some(tree) = btree.as_mut() {
                         // Remove our tuple from the set
                         tree.remove(&(output_len, input_len, entry_id));
@@ -317,10 +447,10 @@ impl State {
                     time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
                     continue
                 }
+                prefill_stats = next_prefill_stats;
             }
 
-            total_token_count = next_total_token_count;
-            prefill_size += input_len;
+            batch_stats = next_stats;
 
             chosen_indices.push(queue_index - 1);
             total_count += 1;
@@ -349,7 +479,7 @@ impl State {
         // If this is to be added to an existing batch, ensure it meets urgency or size
         // requirements to avoid too frequent prefills
         if !self.next_entry_waiting_too_long() {
-            if total_token_count < config.weight_limit / 2 {
+            if <B as BatchType>::batch_weight(&batch_stats, total_count) < config.weight_limit / 2 {
                 // Don't add this new batch yet because it's not large enough
                 self.checked_request_count = checked_up_to_index;
                 self.buffer_contents_insufficient = true;
router/src/server.rs (path inferred from hunk context)
@@ -17,15 +17,17 @@ use futures::stream::StreamExt;
 use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
+use std::marker::PhantomData;
 use std::net::SocketAddr;
 use text_generation_client::ShardedClient;
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::time::Instant;
 use tower_http::cors::{AllowOrigin, CorsLayer};
-use tracing::{info_span, instrument, Instrument};
+use tracing::{info_span, instrument, Instrument, warn};
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;
+use crate::queue::{BatchType, FlashBatch, PaddedBatch};
 
 /// Generate tokens if `stream == false` or a stream of token if `stream == true`
 #[utoipa::path(
@@ -46,9 +48,9 @@ use utoipa_swagger_ui::SwaggerUi;
     )
 )]
 #[instrument(skip(infer))]
-async fn compat_generate(
+async fn compat_generate<B: BatchType>(
     default_return_full_text: Extension<bool>,
-    infer: Extension<Infer>,
+    infer: Extension<Infer<B>>,
     req: Json<CompatGenerateRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let mut req = req.0;
@@ -92,7 +94,9 @@ async fn get_model_info(model_info: Extension<ModelInfo>) -> Json<Info> {
 
 /// Health check method
 #[instrument(skip(infer))]
-async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
+async fn health<B: BatchType>(
+    infer: Extension<Infer<B>>) -> Result<(), (StatusCode, Json<ErrorResponse>
+)> {
     // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
     // be a bit too slow for a health check.
     // What we should do instead is check if the gRPC channels are still healthy.
@@ -151,8 +155,8 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
         seed,
     )
 )]
-async fn generate(
-    infer: Extension<Infer>,
+async fn generate<B: BatchType>(
+    infer: Extension<Infer<B>>,
     req: Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
@@ -333,8 +337,8 @@ async fn generate(
         seed,
     )
 )]
-async fn generate_stream(
-    infer: Extension<Infer>,
+async fn generate_stream<B: BatchType + 'static> (
+    infer: Extension<Infer<B>>,
     req: Json<GenerateRequest>,
 ) -> (
     HeaderMap,
@@ -494,6 +498,59 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     prom_handle.render()
 }
 
+struct BatchConfigValidator<B: BatchType> {
+    batch_type: PhantomData<B>
+}
+
+impl<B: BatchType> BatchConfigValidator<B> {
+    fn validate_batch_config(
+        &self,
+        max_total_tokens: usize,
+        max_batch_size: usize,
+        max_batch_weight: Option<usize>,
+        max_prefill_weight: Option<usize>,
+    ) -> (usize, usize) {
+        let single_request_stats = <B as BatchType>::update_stats(
+            &B::Stats::default(), max_total_tokens, 0
+        );
+        let single_request_weight = <B as BatchType>::batch_weight(
+            &single_request_stats, 1
+        );
+        let weight_upper_bound = single_request_weight * max_batch_size;
+
+        let max_prefill_weight = if let Some(max_prefill_weight) = max_prefill_weight {
+            let single_request_prefill_weight = <B as BatchType>::prefill_weight(
+                &single_request_stats, 1
+            );
+            if max_prefill_weight < single_request_prefill_weight {
+                panic!("max_prefill_weight not large enough for max_total_tokens")
+            }
+            max_prefill_weight
+        } else {
+            0
+        };
+
+        let max_batch_weight = if let Some(mut max_batch_weight) = max_batch_weight {
+            if max_batch_weight < single_request_weight {
+                panic!("max_batch_weight not large enough for max_total_tokens")
+            }
+            if max_batch_weight > weight_upper_bound {
+                warn!(
+                    "Reducing specified max_batch_weight ({}) to ({}) which is an \
+                    upper bound based on max_total_tokens ({}) and max_batch_size ({})",
+                    max_batch_weight, weight_upper_bound, max_total_tokens, max_batch_size
+                );
+                max_batch_weight = weight_upper_bound
+            }
+            max_batch_weight
+        } else {
+            weight_upper_bound
+        };
+
+        (max_batch_weight, max_prefill_weight)
+    }
+}
+
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
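The defaults that validate_batch_config derives differ sharply between the two batch types. A standalone sketch of the upper-bound arithmetic for hypothetical limits, following the formulas above rather than any API of this crate:

    fn main() {
        // Hypothetical CLI limits.
        let (max_total_tokens, max_batch_size) = (1024usize, 8usize);

        // FlashBatch: a worst-case request weighs max_total_tokens, so the
        // derived upper bound is simply 1024 * 8.
        let flash_upper_bound = max_total_tokens * max_batch_size;
        assert_eq!(flash_upper_bound, 8_192);

        // PaddedBatch: the same request weighs (1024 + 0)^2, so the bound is
        // 1_048_576 * 8; a larger user-supplied max_batch_weight would be
        // clamped down to this with a warning.
        let padded_upper_bound = max_total_tokens.pow(2) * max_batch_size;
        assert_eq!(padded_upper_bound, 8_388_608);
    }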
@@ -513,6 +570,72 @@ pub async fn run(
     validation_workers: usize,
     addr: SocketAddr,
     allow_origin: Option<AllowOrigin>,
+) {
+    //TODO get this from querying shard
+    let flash_attention = true;
+
+    if flash_attention {
+        do_run(
+            model_info,
+            compat_return_full_text,
+            max_concurrent_requests,
+            max_best_of,
+            max_stop_sequences,
+            max_input_length,
+            max_total_tokens,
+            max_batch_size,
+            max_batch_weight,
+            max_prefill_weight,
+            max_waiting_tokens,
+            client,
+            tokenizer,
+            validation_workers,
+            addr,
+            allow_origin,
+            FlashBatch{},
+        ).await
+    } else {
+        do_run(
+            model_info,
+            compat_return_full_text,
+            max_concurrent_requests,
+            max_best_of,
+            max_stop_sequences,
+            max_input_length,
+            max_total_tokens,
+            max_batch_size,
+            max_batch_weight,
+            max_prefill_weight,
+            max_waiting_tokens,
+            client,
+            tokenizer,
+            validation_workers,
+            addr,
+            allow_origin,
+            PaddedBatch{},
+        ).await
+    }
+}
+
+#[allow(clippy::too_many_arguments)]
+async fn do_run<B: BatchType>(
+    model_info: ModelInfo,
+    compat_return_full_text: bool,
+    max_concurrent_requests: usize,
+    max_best_of: usize,
+    max_stop_sequences: usize,
+    max_input_length: usize,
+    max_total_tokens: usize,
+    max_batch_size: usize,
+    max_batch_weight: Option<usize>,
+    max_prefill_weight: Option<usize>,
+    max_waiting_tokens: usize,
+    client: ShardedClient,
+    tokenizer: Option<Tokenizer>,
+    validation_workers: usize,
+    addr: SocketAddr,
+    allow_origin: Option<AllowOrigin>,
+    _batch_type: B,
 ) {
     // OpenAPI documentation
     #[derive(OpenApi)]
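The run/do_run split turns a runtime flag into a compile-time type parameter: both branches call the same generic function and the compiler emits one monomorphized copy per batch type. A minimal sketch of this dispatch shape, with hypothetical trait and type names:

    trait Mode {
        const NAME: &'static str;
    }
    struct Flash;
    struct Padded;
    impl Mode for Flash { const NAME: &'static str = "flash"; }
    impl Mode for Padded { const NAME: &'static str = "padded"; }

    // One monomorphized copy of this body exists per mode.
    fn do_run<M: Mode>() {
        println!("running in {} mode", M::NAME);
    }

    // The runtime decision happens once, at the dispatch boundary.
    fn run(flash_attention: bool) {
        if flash_attention {
            do_run::<Flash>()
        } else {
            do_run::<Padded>()
        }
    }

    fn main() {
        run(true);
    }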
@@ -554,14 +677,12 @@ pub async fn run(
     )]
     struct ApiDoc;
 
-    // If max batch weight is not set, infer from max batch size and max seq length
-    let max_batch_weight = max_batch_weight
-        .unwrap_or(max_batch_size * max_total_tokens);
-    let max_prefill_weight = max_prefill_weight.unwrap_or_default();
-
-    if max_total_tokens > max_batch_weight {
-        panic!("max_total_tokens cannot be greater than max_batch_weight");
-    }
+    let batch_config_validator = BatchConfigValidator::<B>{batch_type: PhantomData};
+
+    // If max batch weight is not set, infer from max batch size and max seq length
+    let (max_batch_weight, max_prefill_weight) = batch_config_validator.validate_batch_config(
+        max_total_tokens, max_batch_size, max_batch_weight, max_prefill_weight,
+    );
 
     // Create state
     let validation = Validation::new(
||||||
@ -580,6 +701,7 @@ pub async fn run(
|
|||||||
max_prefill_weight,
|
max_prefill_weight,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
|
FlashBatch{}
|
||||||
);
|
);
|
||||||
|
|
||||||
// Duration buckets
|
// Duration buckets
|
||||||
@@ -639,18 +761,18 @@ pub async fn run(
     let app = Router::new()
         .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
         // Base routes
-        .route("/", post(compat_generate))
+        .route("/", post(compat_generate::<B>))
         .route("/info", get(get_model_info))
-        .route("/generate", post(generate))
-        .route("/generate_stream", post(generate_stream))
+        .route("/generate", post(generate::<B>))
+        .route("/generate_stream", post(generate_stream::<B>))
         // AWS Sagemaker route
-        .route("/invocations", post(compat_generate))
+        .route("/invocations", post(compat_generate::<B>))
         // Base Health route
-        .route("/health", get(health))
+        .route("/health", get(health::<B>))
         // Inference API health route
-        .route("/", get(health))
+        .route("/", get(health::<B>))
         // AWS Sagemaker health route
-        .route("/ping", get(health))
+        .route("/ping", get(health::<B>))
         // Prometheus metrics route
         .route("/metrics", get(metrics))
         .layer(Extension(model_info))
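Because the handlers are now generic, every route registration pins the concrete BatchType with a turbofish, and axum registers a separate monomorphized handler per instantiation. A minimal sketch of the same pattern outside this codebase, assuming an axum 0.6-era API and a stand-in u64 marker type:

    use axum::{routing::get, Extension, Router};

    // A handler generic over a marker type, mirroring health::<B> above.
    async fn health<M: Clone + Send + Sync + 'static>(
        Extension(_marker): Extension<M>,
    ) -> &'static str {
        "ok"
    }

    fn app() -> Router {
        // The turbofish selects the concrete instantiation at registration.
        Router::new()
            .route("/health", get(health::<u64>))
            .layer(Extension(0u64))
    }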