mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00
feat(router): dynamic batch sizing
This commit is contained in:
parent
709d8936f6
commit
ba1aae3e78
@ -41,6 +41,10 @@ struct Args {
|
||||
max_total_tokens: usize,
|
||||
#[clap(default_value = "32", long, env)]
|
||||
max_batch_size: usize,
|
||||
#[clap(default_value = None, long, env)]
|
||||
max_batch_weight: Option<usize>,
|
||||
#[clap(default_value = None, long, env)]
|
||||
max_prefill_weight: Option<usize>,
|
||||
#[clap(default_value = "20", long, env)]
|
||||
max_waiting_tokens: usize,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
@ -93,6 +97,8 @@ fn main() -> ExitCode {
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
max_batch_weight,
|
||||
max_prefill_weight,
|
||||
max_waiting_tokens,
|
||||
port,
|
||||
shard_uds_path,
|
||||
@ -392,6 +398,16 @@ fn main() -> ExitCode {
|
||||
model_id,
|
||||
];
|
||||
|
||||
if let Some(max_batch_weight) = max_batch_weight {
|
||||
argv.push("--max-batch-weight".to_string());
|
||||
argv.push(max_batch_weight.to_string())
|
||||
}
|
||||
|
||||
if let Some(max_prefill_weight) = max_prefill_weight {
|
||||
argv.push("--max-batch-weight".to_string());
|
||||
argv.push(max_prefill_weight.to_string())
|
||||
}
|
||||
|
||||
// Model optional revision
|
||||
if let Some(ref revision) = revision {
|
||||
argv.push("--revision".to_string());
|
||||
|
@ -15,6 +15,7 @@ use thiserror::Error;
|
||||
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
use crate::queue::BatchingConfig;
|
||||
|
||||
/// Inference struct
|
||||
#[derive(Clone)]
|
||||
@ -40,11 +41,17 @@ impl Infer {
|
||||
client: ShardedClient,
|
||||
validation: Validation,
|
||||
max_batch_size: usize,
|
||||
max_batch_weight: usize,
|
||||
max_prefill_weight: usize,
|
||||
max_waiting_tokens: usize,
|
||||
max_concurrent_requests: usize,
|
||||
) -> Self {
|
||||
// Infer shared state
|
||||
let queue = Queue::new();
|
||||
let queue = Queue::new(BatchingConfig {
|
||||
size_limit: max_batch_size,
|
||||
weight_limit: max_batch_weight,
|
||||
prefill_weight_limit: max_prefill_weight,
|
||||
});
|
||||
let shared = Arc::new(Shared {
|
||||
batching_task: Notify::new(),
|
||||
});
|
||||
@ -52,7 +59,6 @@ impl Infer {
|
||||
// Spawn batching background task that contains all the inference logic
|
||||
tokio::spawn(batching_task(
|
||||
client,
|
||||
max_batch_size,
|
||||
max_waiting_tokens,
|
||||
queue.clone(),
|
||||
shared.clone(),
|
||||
@ -105,6 +111,7 @@ impl Infer {
|
||||
// Append the request to the queue
|
||||
self.queue.append(Entry {
|
||||
request: valid_request,
|
||||
generated_tokens: 0,
|
||||
response_tx,
|
||||
span: Span::current(),
|
||||
temp_span: None,
|
||||
@ -232,18 +239,11 @@ impl Infer {
|
||||
/// Batches requests and sends them to the inference server
|
||||
async fn batching_task(
|
||||
mut client: ShardedClient,
|
||||
max_batch_size: usize,
|
||||
// max_batch_size: usize,
|
||||
max_waiting_tokens: usize,
|
||||
queue: Queue,
|
||||
shared: Arc<Shared>,
|
||||
) {
|
||||
// Minimum batch size after which we try to add more requests
|
||||
let limit_min_batch_size = if max_batch_size > 1 {
|
||||
(max_batch_size / 2) as u32
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
// Infinite loop
|
||||
loop {
|
||||
// Wait for a notification from the Infer struct
|
||||
@ -252,8 +252,8 @@ async fn batching_task(
|
||||
// Get the next batch from the queue
|
||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||
// waiting in the queue
|
||||
while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
|
||||
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
||||
while let (_, Some((mut entries, batch, span))) = queue.next_batch(None).await {
|
||||
let (mut cached_batch, mut some_completed) = prefill(&mut client, batch, &mut entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
let mut waiting_tokens = 1;
|
||||
@ -266,21 +266,16 @@ async fn batching_task(
|
||||
let mut batches = vec![batch];
|
||||
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
|
||||
|
||||
// If the current batch is too small, we try to add more requests to it
|
||||
if batch_size <= limit_min_batch_size {
|
||||
let min_size = match waiting_tokens {
|
||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||
// to add a new batch even though its size might be small
|
||||
_ if waiting_tokens >= max_waiting_tokens => None,
|
||||
// Minimum size criteria
|
||||
_ => Some(limit_min_batch_size as usize),
|
||||
};
|
||||
// Try to extend batch if its size reduced or enough tokens have elapsed since last one
|
||||
if some_completed || waiting_tokens >= max_waiting_tokens {
|
||||
|
||||
// Try to get a new batch
|
||||
if let Some((mut new_entries, new_batch, span)) = queue
|
||||
.next_batch(min_size, max_batch_size - batch_size as usize)
|
||||
.await
|
||||
{
|
||||
// Try to get a new batch - ownership of entries passed in and out
|
||||
let (
|
||||
existing_entries, new_entries
|
||||
) = queue.next_batch(Some(entries)).await;
|
||||
entries = existing_entries.unwrap();
|
||||
|
||||
if let Some((mut new_entries, new_batch, span)) = new_entries {
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to add the info that this entry is waiting
|
||||
// because a new batch is being computed
|
||||
@ -293,7 +288,7 @@ async fn batching_task(
|
||||
});
|
||||
|
||||
// Generate one token for this new batch to have the attention past in cache
|
||||
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
|
||||
let (new_cached_batch, _) = prefill(&mut client, new_batch, &mut new_entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
// Reset waiting counter
|
||||
@ -319,7 +314,7 @@ async fn batching_task(
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
});
|
||||
|
||||
cached_batch = decode(&mut client, batches, &mut entries)
|
||||
(cached_batch, some_completed) = decode(&mut client, batches, &mut entries)
|
||||
.instrument(next_batch_span)
|
||||
.await;
|
||||
waiting_tokens += 1;
|
||||
@ -334,14 +329,14 @@ async fn prefill(
|
||||
client: &mut ShardedClient,
|
||||
batch: Batch,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<Batch> {
|
||||
) -> (Option<Batch>, bool) {
|
||||
let start_time = Instant::now();
|
||||
let batch_id = batch.id;
|
||||
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
|
||||
|
||||
match client.prefill(batch).await {
|
||||
Ok((generations, next_batch)) => {
|
||||
filter_send_generations(generations, entries);
|
||||
let some_completed = filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = match next_batch {
|
||||
@ -360,14 +355,14 @@ async fn prefill(
|
||||
|
||||
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
|
||||
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
|
||||
next_batch
|
||||
(next_batch, some_completed)
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
let _ = client.clear_cache(Some(batch_id)).await;
|
||||
send_errors(err, entries);
|
||||
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
|
||||
None
|
||||
(None, true)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -377,14 +372,14 @@ async fn decode(
|
||||
client: &mut ShardedClient,
|
||||
batches: Vec<Batch>,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<Batch> {
|
||||
) -> (Option<Batch>, bool) {
|
||||
let start_time = Instant::now();
|
||||
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
||||
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
|
||||
|
||||
match client.decode(batches).await {
|
||||
Ok((generations, next_batch)) => {
|
||||
filter_send_generations(generations, entries);
|
||||
let some_completed = filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = match next_batch {
|
||||
@ -403,7 +398,7 @@ async fn decode(
|
||||
|
||||
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
|
||||
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
|
||||
next_batch
|
||||
(next_batch, some_completed)
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
@ -412,7 +407,7 @@ async fn decode(
|
||||
}
|
||||
send_errors(err, entries);
|
||||
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
|
||||
None
|
||||
(None, true)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -431,14 +426,16 @@ fn filter_batch(mut batch: Batch, entries: &IntMap<u64, Entry>) -> Option<Batch>
|
||||
|
||||
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||
/// and filter entries
|
||||
/// Return true if any requests completed
|
||||
#[instrument(skip_all)]
|
||||
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
||||
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) -> bool {
|
||||
let mut some_stopped = false;
|
||||
generations.into_iter().for_each(|generation| {
|
||||
let id = generation.request_id;
|
||||
// Get entry
|
||||
// We can `expect` here as the request id should always be in the entries
|
||||
let entry = entries
|
||||
.get(&id)
|
||||
.get_mut(&id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
|
||||
// Create and enter a span to link this function back to the entry
|
||||
@ -451,9 +448,14 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
|
||||
err
|
||||
}).unwrap_or(true);
|
||||
if stopped {
|
||||
some_stopped = true;
|
||||
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||
} else {
|
||||
// Increment generated token count
|
||||
entry.generated_tokens += 1;
|
||||
}
|
||||
});
|
||||
return some_stopped;
|
||||
}
|
||||
|
||||
/// Send responses through the `entry` response channel
|
||||
|
@ -33,6 +33,10 @@ struct Args {
|
||||
max_total_tokens: usize,
|
||||
#[clap(default_value = "32", long, env)]
|
||||
max_batch_size: usize,
|
||||
#[clap(default_value = None, long, env)]
|
||||
max_batch_weight: Option<usize>,
|
||||
#[clap(default_value = None, long, env)]
|
||||
max_prefill_weight: Option<usize>,
|
||||
#[clap(default_value = "20", long, env)]
|
||||
max_waiting_tokens: usize,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
@ -64,6 +68,8 @@ fn main() -> Result<(), std::io::Error> {
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
max_batch_weight,
|
||||
max_prefill_weight,
|
||||
max_waiting_tokens,
|
||||
port,
|
||||
master_shard_uds_path,
|
||||
@ -169,6 +175,8 @@ fn main() -> Result<(), std::io::Error> {
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
max_batch_weight,
|
||||
max_prefill_weight,
|
||||
max_waiting_tokens,
|
||||
sharded_client,
|
||||
tokenizer,
|
||||
|
@ -2,8 +2,9 @@ use crate::infer::InferError;
|
||||
use crate::infer::InferStreamResponse;
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::min;
|
||||
use std::collections::VecDeque;
|
||||
use std::collections::{BTreeSet, VecDeque};
|
||||
use std::ops::Add;
|
||||
use std::time::Duration;
|
||||
use text_generation_client::{Batch, Request};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::Instant;
|
||||
@ -14,6 +15,8 @@ use tracing::{info_span, instrument, Span};
|
||||
pub(crate) struct Entry {
|
||||
/// Request
|
||||
pub request: ValidGenerateRequest,
|
||||
/// Count of tokens generated so far
|
||||
pub generated_tokens: usize,
|
||||
/// Response sender to communicate between the Infer struct and the batching_task
|
||||
pub response_tx: flume::Sender<Result<InferStreamResponse, InferError>>,
|
||||
/// Span that will live as long as entry
|
||||
@ -34,12 +37,12 @@ pub(crate) struct Queue {
|
||||
}
|
||||
|
||||
impl Queue {
|
||||
pub(crate) fn new() -> Self {
|
||||
pub(crate) fn new(config: BatchingConfig) -> Self {
|
||||
// Create channel
|
||||
let (queue_sender, queue_receiver) = flume::unbounded();
|
||||
|
||||
// Launch background queue task
|
||||
tokio::spawn(queue_task(queue_receiver));
|
||||
tokio::spawn(queue_task(queue_receiver, config));
|
||||
|
||||
Self { queue_sender }
|
||||
}
|
||||
@ -54,21 +57,19 @@ impl Queue {
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Get the next batch
|
||||
// Get the next batch - existing batch is returned unchanged
|
||||
#[instrument(skip(self))]
|
||||
pub(crate) async fn next_batch(
|
||||
&self,
|
||||
min_size: Option<usize>,
|
||||
max_size: usize,
|
||||
) -> Option<NextBatch> {
|
||||
entries: Option<ExistingBatch>,
|
||||
) -> (Option<ExistingBatch>, Option<NextBatch>) {
|
||||
// Create response channel
|
||||
let (response_sender, response_receiver) = oneshot::channel();
|
||||
// Send next batch command to the background task managing the state
|
||||
// Unwrap is safe here
|
||||
self.queue_sender
|
||||
.send(QueueCommand::NextBatch {
|
||||
min_size,
|
||||
max_size,
|
||||
entries,
|
||||
response_sender,
|
||||
span: Span::current(),
|
||||
})
|
||||
@ -80,28 +81,37 @@ impl Queue {
|
||||
}
|
||||
|
||||
// Background task responsible of the queue state
|
||||
async fn queue_task(receiver: flume::Receiver<QueueCommand>) {
|
||||
let mut state = State::new();
|
||||
async fn queue_task(receiver: flume::Receiver<QueueCommand>, config: BatchingConfig) {
|
||||
let mut state = State::new(config);
|
||||
|
||||
while let Ok(cmd) = receiver.recv_async().await {
|
||||
match cmd {
|
||||
QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)),
|
||||
QueueCommand::NextBatch {
|
||||
min_size,
|
||||
max_size,
|
||||
entries,
|
||||
response_sender,
|
||||
span,
|
||||
} => span.in_scope(|| {
|
||||
let next_batch = state.next_batch(min_size, max_size);
|
||||
response_sender.send(next_batch).unwrap_or(());
|
||||
let response = state.next_batch(entries);
|
||||
response_sender.send(response).unwrap_or(());
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct BatchingConfig {
|
||||
pub(crate) size_limit: usize,
|
||||
pub(crate) weight_limit: usize,
|
||||
pub(crate) prefill_weight_limit: usize,
|
||||
}
|
||||
|
||||
/// Queue State
|
||||
#[derive(Debug)]
|
||||
struct State {
|
||||
/// Batching configuration
|
||||
config: BatchingConfig,
|
||||
|
||||
/// Queue entries organized in a Vec
|
||||
entries: VecDeque<(u64, Entry)>,
|
||||
|
||||
@ -110,14 +120,46 @@ struct State {
|
||||
|
||||
/// Id of the next batch
|
||||
next_batch_id: u64,
|
||||
|
||||
// Remembered size of the last batch, used to determine
|
||||
// when entries have completed between calls to the
|
||||
// next_batch function
|
||||
last_seen_batch_size: usize,
|
||||
|
||||
// Index in the queue up to which entries have been
|
||||
// checked to see if they can fit into the current batch.
|
||||
// Reset to zero when any existing entries complete
|
||||
checked_request_count: usize,
|
||||
|
||||
/// true if it's known that the current size of the
|
||||
/// requests in the queue is too small to prefill an add-on batch
|
||||
buffer_contents_insufficient: bool,
|
||||
|
||||
/// Just a constant empty map to reuse
|
||||
empty_map: ExistingBatch,
|
||||
}
|
||||
|
||||
// Could also make these configurable
|
||||
|
||||
/// Longest that requests can be waiting before we ignore the minimum
|
||||
/// size requirement when adding to a new batch
|
||||
const MAX_WAITING_DURATION: Duration = Duration::from_secs(1);
|
||||
|
||||
/// Maximum difference in arrival time that smaller requests can jump
|
||||
/// ahead of larger ones in the queue
|
||||
const CUTOFF_DURATION: Duration = Duration::from_secs(1);
|
||||
|
||||
impl State {
|
||||
fn new() -> Self {
|
||||
fn new(config: BatchingConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
entries: VecDeque::with_capacity(128),
|
||||
next_id: 0,
|
||||
next_batch_id: 0,
|
||||
last_seen_batch_size: 0,
|
||||
checked_request_count: 0,
|
||||
buffer_contents_insufficient: false,
|
||||
empty_map: IntMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -134,70 +176,221 @@ impl State {
|
||||
}
|
||||
|
||||
// Get the next batch
|
||||
fn next_batch(&mut self, min_size: Option<usize>, max_size: usize) -> Option<NextBatch> {
|
||||
if self.entries.is_empty() {
|
||||
return None;
|
||||
fn next_batch(
|
||||
&mut self, existing_entries_opt: Option<ExistingBatch>,
|
||||
) -> (Option<ExistingBatch>, Option<NextBatch>) {
|
||||
|
||||
// Use ref to empty map in None case to simplify subsequent logic
|
||||
let existing_entries = existing_entries_opt.as_ref().unwrap_or(&self.empty_map);
|
||||
|
||||
let config = &self.config;
|
||||
let mut total_count = existing_entries.len();
|
||||
if total_count >= config.size_limit {
|
||||
// We are already at max batch size
|
||||
return (existing_entries_opt, None)
|
||||
}
|
||||
|
||||
// Check if we have enough entries
|
||||
if let Some(min_size) = min_size {
|
||||
if self.entries.len() < min_size {
|
||||
return None;
|
||||
if total_count != self.last_seen_batch_size {
|
||||
// Reset the count of checked requests if any completed since last check
|
||||
self.checked_request_count = 0;
|
||||
self.last_seen_batch_size = total_count
|
||||
}
|
||||
|
||||
// Filter entries where the response receiver was dropped (== entries where the request
|
||||
// was dropped by the client)
|
||||
let queue_len_before = self.entries.len();
|
||||
self.entries.retain_mut(|(_, entry)| !entry.response_tx.is_disconnected());
|
||||
if queue_len_before != self.entries.len() {
|
||||
// Reset the count of checked requests if any in the queue were cancelled since last check
|
||||
self.checked_request_count = 0;
|
||||
}
|
||||
|
||||
// This will generally be zero, but if no requests have been completed
|
||||
// since last time, we don't need to reconsider those already checked
|
||||
let mut checked_up_to_index = self.checked_request_count;
|
||||
|
||||
if !existing_entries.is_empty() {
|
||||
// If we don't have any new requests in the buffer to check
|
||||
if self.entries.len() <= checked_up_to_index ||
|
||||
// Or the current buffer isn't large enough to satisfy the min prefill requirement
|
||||
self.buffer_contents_insufficient && !self.next_entry_waiting_too_long() {
|
||||
return (existing_entries_opt, None)
|
||||
}
|
||||
}
|
||||
|
||||
let max_batch_size = min(self.entries.len(), max_size);
|
||||
// Indices into buffer of entries chosen to add to next batch or remove due to expiry
|
||||
let mut chosen_indices = vec![];
|
||||
// Indices to drop due to client cancellation
|
||||
let mut indices_to_drop = vec![];
|
||||
let mut btree = None;
|
||||
let mut time_cutoff = None;
|
||||
let mut hit_prefill_weight_limit = false;
|
||||
|
||||
let mut total_token_count = existing_entries.iter().map(
|
||||
|(_, e)| e.request.stopping_parameters.max_new_tokens + e.request.truncate
|
||||
).sum::<u32>() as usize;
|
||||
|
||||
let mut prefill_size = 0;
|
||||
// We first do a read-only pass over the queue to allow skipping over large entries
|
||||
// that don't fit in the current batch to reach smaller entries that do
|
||||
let mut queue_index = checked_up_to_index;
|
||||
'queue_loop: for (entry_id, entry) in self.entries.range(queue_index..) {
|
||||
if matches!(time_cutoff, Some(t) if entry.queue_time > t) {
|
||||
break
|
||||
}
|
||||
queue_index += 1;
|
||||
if entry.response_tx.is_disconnected() {
|
||||
// Eject cancelled entry from queue
|
||||
indices_to_drop.push(queue_index);
|
||||
continue
|
||||
}
|
||||
// This is the index into the queue after cancelled entries
|
||||
// have been pruned
|
||||
checked_up_to_index += 1;
|
||||
|
||||
let input_len = entry.request.truncate as usize;
|
||||
let output_len = entry.request.stopping_parameters.max_new_tokens as usize;
|
||||
let next_total_token_count = total_token_count + input_len + output_len;
|
||||
|
||||
// Avoid more granular analysis if possible
|
||||
if next_total_token_count > config.weight_limit {
|
||||
// We aren't sure whether this next request will fit, so populate
|
||||
// a btree with the current batch of requests, the set of
|
||||
// requests already evaluated, and this one, and perform more
|
||||
// granular analysis to verify that the batch shape won't exceed
|
||||
// the limit at any point
|
||||
|
||||
// Allocate btree the first time it's required
|
||||
let tree = btree.get_or_insert_with(|| {
|
||||
let mut t = Box::new(BTreeSet::new());
|
||||
// Populate with records corresponding to all existing and pending entries
|
||||
let pending = chosen_indices.iter()
|
||||
.map(|i| self.entries.get(*i).unwrap())
|
||||
.map(|(eid, e)| (eid, e));
|
||||
for (eid, e) in existing_entries.iter().chain(pending) {
|
||||
let generated_count = e.generated_tokens;
|
||||
t.insert((
|
||||
e.request.stopping_parameters.max_new_tokens as usize - generated_count,
|
||||
e.request.truncate as usize + e.generated_tokens,
|
||||
eid,
|
||||
));
|
||||
}
|
||||
t
|
||||
});
|
||||
// Add the current entry
|
||||
tree.insert((output_len, input_len, entry_id));
|
||||
|
||||
// Perform analysis
|
||||
let mut in_sum = 0;
|
||||
// Work backwards from longest projected entry
|
||||
for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
|
||||
let this_ol = *ol;
|
||||
in_sum += *il;
|
||||
if this_ol <= output_len {
|
||||
// Check if we breach max space for this segment
|
||||
let token_count = in_sum + (bs + 1) * this_ol;
|
||||
if token_count > config.weight_limit {
|
||||
// Remove our tuple from the set
|
||||
tree.remove(&(output_len, input_len, entry_id));
|
||||
time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
|
||||
continue 'queue_loop
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if let Some(tree) = btree.as_mut() {
|
||||
// If we initialized the btree for a prior request, keep it updated
|
||||
tree.insert((output_len, input_len, entry_id));
|
||||
}
|
||||
// Here, we can add this request to the batch without breaching memory limit
|
||||
|
||||
if config.prefill_weight_limit > 0 {
|
||||
// Also check whether adding this request will make the batch of new requests
|
||||
// too expensive latency-wise to perform in a single forward-pass.
|
||||
if prefill_size + input_len > config.prefill_weight_limit {
|
||||
if let Some(tree) = btree.as_mut() {
|
||||
// Remove our tuple from the set
|
||||
tree.remove(&(output_len, input_len, entry_id));
|
||||
hit_prefill_weight_limit = true;
|
||||
}
|
||||
time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
total_token_count = next_total_token_count;
|
||||
prefill_size += input_len;
|
||||
|
||||
chosen_indices.push(queue_index - 1);
|
||||
total_count += 1;
|
||||
if total_count >= config.size_limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Drop any cancelled requests
|
||||
if !indices_to_drop.is_empty() {
|
||||
indices_to_drop.iter().for_each(|i| {
|
||||
self.entries.remove(*i);
|
||||
});
|
||||
metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
|
||||
}
|
||||
|
||||
let next_batch_size = chosen_indices.len();
|
||||
if next_batch_size == 0 {
|
||||
// This gets reset to zero when any requests in the existing batch are removed
|
||||
self.checked_request_count = checked_up_to_index;
|
||||
return (existing_entries_opt, None)
|
||||
}
|
||||
self.checked_request_count = 0;
|
||||
|
||||
if !hit_prefill_weight_limit && !existing_entries.is_empty() {
|
||||
// If this is to be added to an existing batch, ensure it meets urgency or size
|
||||
// requirements to avoid too frequent prefills
|
||||
if !self.next_entry_waiting_too_long() {
|
||||
if total_token_count < config.weight_limit / 2 {
|
||||
// Don't add this new batch yet because it's not large enough
|
||||
self.checked_request_count = checked_up_to_index;
|
||||
self.buffer_contents_insufficient = true;
|
||||
return (existing_entries_opt, None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = next_batch_size);
|
||||
next_batch_span.follows_from(&Span::current());
|
||||
|
||||
let mut batch_requests = Vec::with_capacity(max_batch_size);
|
||||
let mut batch_entries =
|
||||
IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default());
|
||||
|
||||
// Iterate on buffer
|
||||
while let Some((id, mut entry)) = self.entries.pop_front() {
|
||||
// Filter entries where the response receiver was dropped (== entries where the request
|
||||
// was dropped by the client)
|
||||
if entry.response_tx.is_disconnected() {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
continue;
|
||||
}
|
||||
IntMap::with_capacity_and_hasher(next_batch_size, BuildNoHashHasher::default());
|
||||
|
||||
let some_now = Some(Instant::now());
|
||||
let batch_requests = chosen_indices.iter().enumerate().map(|(i, index)| {
|
||||
let (id, mut entry) = self.entries.remove(index - i).expect("bug");
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
let entry_batch_span =
|
||||
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
|
||||
// Add relationships
|
||||
next_batch_span.follows_from(&entry_batch_span);
|
||||
entry_batch_span.follows_from(&next_batch_span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
|
||||
batch_requests.push(Request {
|
||||
let request = Request {
|
||||
id,
|
||||
inputs: entry.request.inputs.clone(),
|
||||
truncate: entry.request.truncate,
|
||||
parameters: Some(entry.request.parameters.clone()),
|
||||
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
|
||||
});
|
||||
};
|
||||
// Set batch_time
|
||||
entry.batch_time = Some(Instant::now());
|
||||
entry.batch_time = some_now;
|
||||
// Insert in batch_entries IntMap
|
||||
batch_entries.insert(id, entry);
|
||||
|
||||
if batch_requests.len() == max_batch_size {
|
||||
// We have enough requests in the batch
|
||||
break;
|
||||
}
|
||||
}
|
||||
request
|
||||
}).collect::<Vec<Request>>();
|
||||
|
||||
metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
|
||||
|
||||
// Maybe all entries were dropped because their channel were closed
|
||||
if batch_requests.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Final batch size once we dropped entries
|
||||
let size = batch_requests.len() as u32;
|
||||
next_batch_span.record("batch_size", size);
|
||||
@ -209,224 +402,30 @@ impl State {
|
||||
};
|
||||
// Increment batch id
|
||||
self.next_batch_id += 1;
|
||||
self.buffer_contents_insufficient = false;
|
||||
|
||||
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
|
||||
Some((batch_entries, batch, next_batch_span))
|
||||
(existing_entries_opt, Some((batch_entries, batch, next_batch_span)))
|
||||
}
|
||||
|
||||
/// Returns true if the entry at the front of the queue has been waiting for longer
|
||||
/// than MAX_WAITING_DURATION
|
||||
fn next_entry_waiting_too_long(&self) -> bool {
|
||||
matches!(
|
||||
self.entries.front(), Some((_, e)) if e.queue_time.elapsed() > MAX_WAITING_DURATION
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
type ExistingBatch = IntMap<u64, Entry>;
|
||||
type NextBatch = (IntMap<u64, Entry>, Batch, Span);
|
||||
|
||||
#[derive(Debug)]
|
||||
enum QueueCommand {
|
||||
Append(Entry, Span),
|
||||
NextBatch {
|
||||
min_size: Option<usize>,
|
||||
max_size: usize,
|
||||
response_sender: oneshot::Sender<Option<NextBatch>>,
|
||||
entries: Option<ExistingBatch>,
|
||||
response_sender: oneshot::Sender<(Option<ExistingBatch>, Option<NextBatch>)>,
|
||||
span: Span,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
||||
use tracing::info_span;
|
||||
|
||||
fn default_entry() -> (
|
||||
Entry,
|
||||
flume::Receiver<Result<InferStreamResponse, InferError>>,
|
||||
) {
|
||||
let (response_tx, receiver_tx) = flume::unbounded();
|
||||
|
||||
let entry = Entry {
|
||||
request: ValidGenerateRequest {
|
||||
inputs: "".to_string(),
|
||||
truncate: 0,
|
||||
parameters: NextTokenChooserParameters {
|
||||
temperature: 0.0,
|
||||
top_k: 0,
|
||||
top_p: 0.0,
|
||||
typical_p: 0.0,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 0.0,
|
||||
watermark: false,
|
||||
},
|
||||
stopping_parameters: StoppingCriteriaParameters {
|
||||
ignore_eos_token: false,
|
||||
max_new_tokens: 0,
|
||||
stop_sequences: vec![],
|
||||
},
|
||||
},
|
||||
response_tx,
|
||||
span: info_span!("entry"),
|
||||
temp_span: None,
|
||||
queue_time: Instant::now(),
|
||||
batch_time: None,
|
||||
};
|
||||
(entry, receiver_tx)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_append() {
|
||||
let mut state = State::new();
|
||||
let (entry, _guard) = default_entry();
|
||||
|
||||
assert_eq!(state.next_id, 0);
|
||||
assert_eq!(state.entries.len(), 0);
|
||||
|
||||
state.append(entry);
|
||||
|
||||
assert_eq!(state.next_id, 1);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
let (id, _) = state.entries.remove(0).unwrap();
|
||||
assert_eq!(id, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_empty() {
|
||||
let mut state = State::new();
|
||||
|
||||
assert!(state.next_batch(None, 1).is_none());
|
||||
assert!(state.next_batch(Some(1), 1).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_min_size() {
|
||||
let mut state = State::new();
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, 2).unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
||||
assert!(entries.get(&1).unwrap().batch_time.is_some());
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 2);
|
||||
|
||||
assert_eq!(state.next_id, 2);
|
||||
assert_eq!(state.entries.len(), 0);
|
||||
assert_eq!(state.next_batch_id, 1);
|
||||
|
||||
let (entry3, _guard3) = default_entry();
|
||||
state.append(entry3);
|
||||
|
||||
assert!(state.next_batch(Some(2), 2).is_none());
|
||||
|
||||
assert_eq!(state.next_id, 3);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
let (id, _) = state.entries.remove(0).unwrap();
|
||||
assert_eq!(id, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_max_size() {
|
||||
let mut state = State::new();
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, 1).unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 1);
|
||||
|
||||
assert_eq!(state.next_id, 2);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
assert_eq!(state.next_batch_id, 1);
|
||||
|
||||
let (entry3, _guard3) = default_entry();
|
||||
state.append(entry3);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, 3).unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.contains_key(&2));
|
||||
assert_eq!(batch.id, 1);
|
||||
assert_eq!(batch.size, 2);
|
||||
|
||||
assert_eq!(state.next_id, 3);
|
||||
assert_eq!(state.entries.len(), 0);
|
||||
assert_eq!(state.next_batch_id, 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_append() {
|
||||
let queue = Queue::new();
|
||||
let (entry, _guard) = default_entry();
|
||||
queue.append(entry);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_empty() {
|
||||
let queue = Queue::new();
|
||||
|
||||
assert!(queue.next_batch(None, 1).await.is_none());
|
||||
assert!(queue.next_batch(Some(1), 1).await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_min_size() {
|
||||
let queue = Queue::new();
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
queue.append(entry2);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
||||
assert!(entries.get(&1).unwrap().batch_time.is_some());
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 2);
|
||||
|
||||
let (entry3, _guard3) = default_entry();
|
||||
queue.append(entry3);
|
||||
|
||||
assert!(queue.next_batch(Some(2), 2).await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_max_size() {
|
||||
let queue = Queue::new();
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
queue.append(entry2);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 1);
|
||||
|
||||
let (entry3, _guard3) = default_entry();
|
||||
queue.append(entry3);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.contains_key(&2));
|
||||
assert_eq!(batch.id, 1);
|
||||
assert_eq!(batch.size, 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_dropped_receiver() {
|
||||
let queue = Queue::new();
|
||||
let (entry, _) = default_entry();
|
||||
queue.append(entry);
|
||||
|
||||
assert!(queue.next_batch(None, 1).await.is_none());
|
||||
}
|
||||
}
|
||||
|
@ -505,6 +505,8 @@ pub async fn run(
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
max_batch_size: usize,
|
||||
max_batch_weight: Option<usize>,
|
||||
max_prefill_weight: Option<usize>,
|
||||
max_waiting_tokens: usize,
|
||||
client: ShardedClient,
|
||||
tokenizer: Option<Tokenizer>,
|
||||
@ -552,6 +554,15 @@ pub async fn run(
|
||||
)]
|
||||
struct ApiDoc;
|
||||
|
||||
// If max batch weight is not set, infer from max batch size and max seq length
|
||||
let max_batch_weight = max_batch_weight
|
||||
.unwrap_or(max_batch_size * max_total_tokens);
|
||||
let max_prefill_weight = max_prefill_weight.unwrap_or_default();
|
||||
|
||||
if max_total_tokens > max_batch_weight {
|
||||
panic!("max_total_tokens cannot be greater than max_batch_weight");
|
||||
}
|
||||
|
||||
// Create state
|
||||
let validation = Validation::new(
|
||||
validation_workers,
|
||||
@ -565,6 +576,8 @@ pub async fn run(
|
||||
client,
|
||||
validation,
|
||||
max_batch_size,
|
||||
max_batch_weight,
|
||||
max_prefill_weight,
|
||||
max_waiting_tokens,
|
||||
max_concurrent_requests,
|
||||
);
|
||||
|
@ -1,3 +1,4 @@
|
||||
use std::cmp::min;
|
||||
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
||||
/// Payload validation logic
|
||||
use crate::{GenerateParameters, GenerateRequest};
|
||||
@ -69,7 +70,7 @@ impl Validation {
|
||||
inputs: String,
|
||||
truncate: Option<usize>,
|
||||
max_new_tokens: u32,
|
||||
) -> Result<String, ValidationError> {
|
||||
) -> Result<(String, usize), ValidationError> {
|
||||
// If we have a fast tokenizer
|
||||
if let Some(sender) = &self.sender {
|
||||
// Create response channel
|
||||
@ -105,7 +106,7 @@ impl Validation {
|
||||
}
|
||||
|
||||
metrics::histogram!("tgi_request_input_length", input_length as f64);
|
||||
Ok(inputs)
|
||||
Ok((inputs, input_length))
|
||||
}
|
||||
// Return inputs without validation
|
||||
else {
|
||||
@ -123,7 +124,7 @@ impl Validation {
|
||||
));
|
||||
}
|
||||
|
||||
Ok(inputs)
|
||||
Ok((inputs, min(truncate.unwrap_or(usize::MAX), self.max_input_length)))
|
||||
}
|
||||
}
|
||||
|
||||
@ -238,7 +239,7 @@ impl Validation {
|
||||
.unwrap_or(Ok(None))?;
|
||||
|
||||
// Validate inputs
|
||||
let inputs = self
|
||||
let (inputs, input_length) = self
|
||||
.validate_input(request.inputs, truncate, max_new_tokens)
|
||||
.await?;
|
||||
|
||||
@ -262,7 +263,7 @@ impl Validation {
|
||||
|
||||
Ok(ValidGenerateRequest {
|
||||
inputs,
|
||||
truncate: truncate.unwrap_or(self.max_input_length) as u32,
|
||||
truncate: input_length as u32,
|
||||
parameters,
|
||||
stopping_parameters,
|
||||
})
|
||||
|
Loading…
Reference in New Issue
Block a user