feat(router): dynamic batch sizing

This commit is contained in:
Nick Hill 2023-04-19 07:58:40 -07:00
parent 709d8936f6
commit ba1aae3e78
6 changed files with 341 additions and 302 deletions

View File

@ -41,6 +41,10 @@ struct Args {
max_total_tokens: usize,
#[clap(default_value = "32", long, env)]
max_batch_size: usize,
#[clap(default_value = None, long, env)]
max_batch_weight: Option<usize>,
#[clap(default_value = None, long, env)]
max_prefill_weight: Option<usize>,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(default_value = "3000", long, short, env)]
@ -93,6 +97,8 @@ fn main() -> ExitCode {
max_input_length,
max_total_tokens,
max_batch_size,
max_batch_weight,
max_prefill_weight,
max_waiting_tokens,
port,
shard_uds_path,
@ -392,6 +398,16 @@ fn main() -> ExitCode {
model_id,
];
if let Some(max_batch_weight) = max_batch_weight {
argv.push("--max-batch-weight".to_string());
argv.push(max_batch_weight.to_string())
}
if let Some(max_prefill_weight) = max_prefill_weight {
argv.push("--max-batch-weight".to_string());
argv.push(max_prefill_weight.to_string())
}
// Model optional revision
if let Some(ref revision) = revision {
argv.push("--revision".to_string());

View File

@ -15,6 +15,7 @@ use thiserror::Error;
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span};
use crate::queue::BatchingConfig;
/// Inference struct
#[derive(Clone)]
@ -40,11 +41,17 @@ impl Infer {
client: ShardedClient,
validation: Validation,
max_batch_size: usize,
max_batch_weight: usize,
max_prefill_weight: usize,
max_waiting_tokens: usize,
max_concurrent_requests: usize,
) -> Self {
// Infer shared state
let queue = Queue::new();
let queue = Queue::new(BatchingConfig {
size_limit: max_batch_size,
weight_limit: max_batch_weight,
prefill_weight_limit: max_prefill_weight,
});
let shared = Arc::new(Shared {
batching_task: Notify::new(),
});
@ -52,7 +59,6 @@ impl Infer {
// Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task(
client,
max_batch_size,
max_waiting_tokens,
queue.clone(),
shared.clone(),
@ -105,6 +111,7 @@ impl Infer {
// Append the request to the queue
self.queue.append(Entry {
request: valid_request,
generated_tokens: 0,
response_tx,
span: Span::current(),
temp_span: None,
@ -232,18 +239,11 @@ impl Infer {
/// Batches requests and sends them to the inference server
async fn batching_task(
mut client: ShardedClient,
max_batch_size: usize,
// max_batch_size: usize,
max_waiting_tokens: usize,
queue: Queue,
shared: Arc<Shared>,
) {
// Minimum batch size after which we try to add more requests
let limit_min_batch_size = if max_batch_size > 1 {
(max_batch_size / 2) as u32
} else {
0
};
// Infinite loop
loop {
// Wait for a notification from the Infer struct
@ -252,8 +252,8 @@ async fn batching_task(
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
let mut cached_batch = prefill(&mut client, batch, &mut entries)
while let (_, Some((mut entries, batch, span))) = queue.next_batch(None).await {
let (mut cached_batch, mut some_completed) = prefill(&mut client, batch, &mut entries)
.instrument(span)
.await;
let mut waiting_tokens = 1;
@ -266,21 +266,16 @@ async fn batching_task(
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
// If the current batch is too small, we try to add more requests to it
if batch_size <= limit_min_batch_size {
let min_size = match waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
_ if waiting_tokens >= max_waiting_tokens => None,
// Minimum size criteria
_ => Some(limit_min_batch_size as usize),
};
// Try to extend batch if its size reduced or enough tokens have elapsed since last one
if some_completed || waiting_tokens >= max_waiting_tokens {
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_batch_size - batch_size as usize)
.await
{
// Try to get a new batch - ownership of entries passed in and out
let (
existing_entries, new_entries
) = queue.next_batch(Some(entries)).await;
entries = existing_entries.unwrap();
if let Some((mut new_entries, new_batch, span)) = new_entries {
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
@ -293,7 +288,7 @@ async fn batching_task(
});
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
let (new_cached_batch, _) = prefill(&mut client, new_batch, &mut new_entries)
.instrument(span)
.await;
// Reset waiting counter
@ -319,7 +314,7 @@ async fn batching_task(
entry.temp_span = Some(entry_batch_span);
});
cached_batch = decode(&mut client, batches, &mut entries)
(cached_batch, some_completed) = decode(&mut client, batches, &mut entries)
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
@ -334,14 +329,14 @@ async fn prefill(
client: &mut ShardedClient,
batch: Batch,
entries: &mut IntMap<u64, Entry>,
) -> Option<Batch> {
) -> (Option<Batch>, bool) {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
match client.prefill(batch).await {
Ok((generations, next_batch)) => {
filter_send_generations(generations, entries);
let some_completed = filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = match next_batch {
@ -360,14 +355,14 @@ async fn prefill(
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
next_batch
(next_batch, some_completed)
}
// If we have an error, we discard the whole batch
Err(err) => {
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
None
(None, true)
}
}
}
@ -377,14 +372,14 @@ async fn decode(
client: &mut ShardedClient,
batches: Vec<Batch>,
entries: &mut IntMap<u64, Entry>,
) -> Option<Batch> {
) -> (Option<Batch>, bool) {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
match client.decode(batches).await {
Ok((generations, next_batch)) => {
filter_send_generations(generations, entries);
let some_completed = filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = match next_batch {
@ -403,7 +398,7 @@ async fn decode(
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
next_batch
(next_batch, some_completed)
}
// If we have an error, we discard the whole batch
Err(err) => {
@ -412,7 +407,7 @@ async fn decode(
}
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
None
(None, true)
}
}
}
@ -431,14 +426,16 @@ fn filter_batch(mut batch: Batch, entries: &IntMap<u64, Entry>) -> Option<Batch>
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
/// Return true if any requests completed
#[instrument(skip_all)]
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) -> bool {
let mut some_stopped = false;
generations.into_iter().for_each(|generation| {
let id = generation.request_id;
// Get entry
// We can `expect` here as the request id should always be in the entries
let entry = entries
.get(&id)
.get_mut(&id)
.expect("ID not found in entries. This is a bug.");
// Create and enter a span to link this function back to the entry
@ -451,9 +448,14 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
err
}).unwrap_or(true);
if stopped {
some_stopped = true;
entries.remove(&id).expect("ID not found in entries. This is a bug.");
} else {
// Increment generated token count
entry.generated_tokens += 1;
}
});
return some_stopped;
}
/// Send responses through the `entry` response channel

View File

@ -33,6 +33,10 @@ struct Args {
max_total_tokens: usize,
#[clap(default_value = "32", long, env)]
max_batch_size: usize,
#[clap(default_value = None, long, env)]
max_batch_weight: Option<usize>,
#[clap(default_value = None, long, env)]
max_prefill_weight: Option<usize>,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(default_value = "3000", long, short, env)]
@ -64,6 +68,8 @@ fn main() -> Result<(), std::io::Error> {
max_input_length,
max_total_tokens,
max_batch_size,
max_batch_weight,
max_prefill_weight,
max_waiting_tokens,
port,
master_shard_uds_path,
@ -169,6 +175,8 @@ fn main() -> Result<(), std::io::Error> {
max_input_length,
max_total_tokens,
max_batch_size,
max_batch_weight,
max_prefill_weight,
max_waiting_tokens,
sharded_client,
tokenizer,

View File

@ -2,8 +2,9 @@ use crate::infer::InferError;
use crate::infer::InferStreamResponse;
use crate::validation::ValidGenerateRequest;
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use std::collections::VecDeque;
use std::collections::{BTreeSet, VecDeque};
use std::ops::Add;
use std::time::Duration;
use text_generation_client::{Batch, Request};
use tokio::sync::oneshot;
use tokio::time::Instant;
@ -14,6 +15,8 @@ use tracing::{info_span, instrument, Span};
pub(crate) struct Entry {
/// Request
pub request: ValidGenerateRequest,
/// Count of tokens generated so far
pub generated_tokens: usize,
/// Response sender to communicate between the Infer struct and the batching_task
pub response_tx: flume::Sender<Result<InferStreamResponse, InferError>>,
/// Span that will live as long as entry
@ -34,12 +37,12 @@ pub(crate) struct Queue {
}
impl Queue {
pub(crate) fn new() -> Self {
pub(crate) fn new(config: BatchingConfig) -> Self {
// Create channel
let (queue_sender, queue_receiver) = flume::unbounded();
// Launch background queue task
tokio::spawn(queue_task(queue_receiver));
tokio::spawn(queue_task(queue_receiver, config));
Self { queue_sender }
}
@ -54,21 +57,19 @@ impl Queue {
.unwrap();
}
// Get the next batch
// Get the next batch - existing batch is returned unchanged
#[instrument(skip(self))]
pub(crate) async fn next_batch(
&self,
min_size: Option<usize>,
max_size: usize,
) -> Option<NextBatch> {
entries: Option<ExistingBatch>,
) -> (Option<ExistingBatch>, Option<NextBatch>) {
// Create response channel
let (response_sender, response_receiver) = oneshot::channel();
// Send next batch command to the background task managing the state
// Unwrap is safe here
self.queue_sender
.send(QueueCommand::NextBatch {
min_size,
max_size,
entries,
response_sender,
span: Span::current(),
})
@ -80,28 +81,37 @@ impl Queue {
}
// Background task responsible of the queue state
async fn queue_task(receiver: flume::Receiver<QueueCommand>) {
let mut state = State::new();
async fn queue_task(receiver: flume::Receiver<QueueCommand>, config: BatchingConfig) {
let mut state = State::new(config);
while let Ok(cmd) = receiver.recv_async().await {
match cmd {
QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)),
QueueCommand::NextBatch {
min_size,
max_size,
entries,
response_sender,
span,
} => span.in_scope(|| {
let next_batch = state.next_batch(min_size, max_size);
response_sender.send(next_batch).unwrap_or(());
let response = state.next_batch(entries);
response_sender.send(response).unwrap_or(());
}),
}
}
}
#[derive(Debug)]
pub(crate) struct BatchingConfig {
pub(crate) size_limit: usize,
pub(crate) weight_limit: usize,
pub(crate) prefill_weight_limit: usize,
}
/// Queue State
#[derive(Debug)]
struct State {
/// Batching configuration
config: BatchingConfig,
/// Queue entries organized in a Vec
entries: VecDeque<(u64, Entry)>,
@ -110,14 +120,46 @@ struct State {
/// Id of the next batch
next_batch_id: u64,
// Remembered size of the last batch, used to determine
// when entries have completed between calls to the
// next_batch function
last_seen_batch_size: usize,
// Index in the queue up to which entries have been
// checked to see if they can fit into the current batch.
// Reset to zero when any existing entries complete
checked_request_count: usize,
/// true if it's known that the current size of the
/// requests in the queue is too small to prefill an add-on batch
buffer_contents_insufficient: bool,
/// Just a constant empty map to reuse
empty_map: ExistingBatch,
}
// Could also make these configurable
/// Longest that requests can be waiting before we ignore the minimum
/// size requirement when adding to a new batch
const MAX_WAITING_DURATION: Duration = Duration::from_secs(1);
/// Maximum difference in arrival time that smaller requests can jump
/// ahead of larger ones in the queue
const CUTOFF_DURATION: Duration = Duration::from_secs(1);
impl State {
fn new() -> Self {
fn new(config: BatchingConfig) -> Self {
Self {
config,
entries: VecDeque::with_capacity(128),
next_id: 0,
next_batch_id: 0,
last_seen_batch_size: 0,
checked_request_count: 0,
buffer_contents_insufficient: false,
empty_map: IntMap::default(),
}
}
@ -134,70 +176,221 @@ impl State {
}
// Get the next batch
fn next_batch(&mut self, min_size: Option<usize>, max_size: usize) -> Option<NextBatch> {
if self.entries.is_empty() {
return None;
fn next_batch(
&mut self, existing_entries_opt: Option<ExistingBatch>,
) -> (Option<ExistingBatch>, Option<NextBatch>) {
// Use ref to empty map in None case to simplify subsequent logic
let existing_entries = existing_entries_opt.as_ref().unwrap_or(&self.empty_map);
let config = &self.config;
let mut total_count = existing_entries.len();
if total_count >= config.size_limit {
// We are already at max batch size
return (existing_entries_opt, None)
}
// Check if we have enough entries
if let Some(min_size) = min_size {
if self.entries.len() < min_size {
return None;
if total_count != self.last_seen_batch_size {
// Reset the count of checked requests if any completed since last check
self.checked_request_count = 0;
self.last_seen_batch_size = total_count
}
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
let queue_len_before = self.entries.len();
self.entries.retain_mut(|(_, entry)| !entry.response_tx.is_disconnected());
if queue_len_before != self.entries.len() {
// Reset the count of checked requests if any in the queue were cancelled since last check
self.checked_request_count = 0;
}
// This will generally be zero, but if no requests have been completed
// since last time, we don't need to reconsider those already checked
let mut checked_up_to_index = self.checked_request_count;
if !existing_entries.is_empty() {
// If we don't have any new requests in the buffer to check
if self.entries.len() <= checked_up_to_index ||
// Or the current buffer isn't large enough to satisfy the min prefill requirement
self.buffer_contents_insufficient && !self.next_entry_waiting_too_long() {
return (existing_entries_opt, None)
}
}
let max_batch_size = min(self.entries.len(), max_size);
// Indices into buffer of entries chosen to add to next batch or remove due to expiry
let mut chosen_indices = vec![];
// Indices to drop due to client cancellation
let mut indices_to_drop = vec![];
let mut btree = None;
let mut time_cutoff = None;
let mut hit_prefill_weight_limit = false;
let mut total_token_count = existing_entries.iter().map(
|(_, e)| e.request.stopping_parameters.max_new_tokens + e.request.truncate
).sum::<u32>() as usize;
let mut prefill_size = 0;
// We first do a read-only pass over the queue to allow skipping over large entries
// that don't fit in the current batch to reach smaller entries that do
let mut queue_index = checked_up_to_index;
'queue_loop: for (entry_id, entry) in self.entries.range(queue_index..) {
if matches!(time_cutoff, Some(t) if entry.queue_time > t) {
break
}
queue_index += 1;
if entry.response_tx.is_disconnected() {
// Eject cancelled entry from queue
indices_to_drop.push(queue_index);
continue
}
// This is the index into the queue after cancelled entries
// have been pruned
checked_up_to_index += 1;
let input_len = entry.request.truncate as usize;
let output_len = entry.request.stopping_parameters.max_new_tokens as usize;
let next_total_token_count = total_token_count + input_len + output_len;
// Avoid more granular analysis if possible
if next_total_token_count > config.weight_limit {
// We aren't sure whether this next request will fit, so populate
// a btree with the current batch of requests, the set of
// requests already evaluated, and this one, and perform more
// granular analysis to verify that the batch shape won't exceed
// the limit at any point
// Allocate btree the first time it's required
let tree = btree.get_or_insert_with(|| {
let mut t = Box::new(BTreeSet::new());
// Populate with records corresponding to all existing and pending entries
let pending = chosen_indices.iter()
.map(|i| self.entries.get(*i).unwrap())
.map(|(eid, e)| (eid, e));
for (eid, e) in existing_entries.iter().chain(pending) {
let generated_count = e.generated_tokens;
t.insert((
e.request.stopping_parameters.max_new_tokens as usize - generated_count,
e.request.truncate as usize + e.generated_tokens,
eid,
));
}
t
});
// Add the current entry
tree.insert((output_len, input_len, entry_id));
// Perform analysis
let mut in_sum = 0;
// Work backwards from longest projected entry
for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
let this_ol = *ol;
in_sum += *il;
if this_ol <= output_len {
// Check if we breach max space for this segment
let token_count = in_sum + (bs + 1) * this_ol;
if token_count > config.weight_limit {
// Remove our tuple from the set
tree.remove(&(output_len, input_len, entry_id));
time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
continue 'queue_loop
}
}
}
} else if let Some(tree) = btree.as_mut() {
// If we initialized the btree for a prior request, keep it updated
tree.insert((output_len, input_len, entry_id));
}
// Here, we can add this request to the batch without breaching memory limit
if config.prefill_weight_limit > 0 {
// Also check whether adding this request will make the batch of new requests
// too expensive latency-wise to perform in a single forward-pass.
if prefill_size + input_len > config.prefill_weight_limit {
if let Some(tree) = btree.as_mut() {
// Remove our tuple from the set
tree.remove(&(output_len, input_len, entry_id));
hit_prefill_weight_limit = true;
}
time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
continue
}
}
total_token_count = next_total_token_count;
prefill_size += input_len;
chosen_indices.push(queue_index - 1);
total_count += 1;
if total_count >= config.size_limit {
break
}
}
// Drop any cancelled requests
if !indices_to_drop.is_empty() {
indices_to_drop.iter().for_each(|i| {
self.entries.remove(*i);
});
metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
}
let next_batch_size = chosen_indices.len();
if next_batch_size == 0 {
// This gets reset to zero when any requests in the existing batch are removed
self.checked_request_count = checked_up_to_index;
return (existing_entries_opt, None)
}
self.checked_request_count = 0;
if !hit_prefill_weight_limit && !existing_entries.is_empty() {
// If this is to be added to an existing batch, ensure it meets urgency or size
// requirements to avoid too frequent prefills
if !self.next_entry_waiting_too_long() {
if total_token_count < config.weight_limit / 2 {
// Don't add this new batch yet because it's not large enough
self.checked_request_count = checked_up_to_index;
self.buffer_contents_insufficient = true;
return (existing_entries_opt, None)
}
}
}
// Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
let next_batch_span = info_span!(parent: None, "batch", batch_size = next_batch_size);
next_batch_span.follows_from(&Span::current());
let mut batch_requests = Vec::with_capacity(max_batch_size);
let mut batch_entries =
IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default());
// Iterate on buffer
while let Some((id, mut entry)) = self.entries.pop_front() {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
if entry.response_tx.is_disconnected() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
continue;
}
IntMap::with_capacity_and_hasher(next_batch_size, BuildNoHashHasher::default());
let some_now = Some(Instant::now());
let batch_requests = chosen_indices.iter().enumerate().map(|(i, index)| {
let (id, mut entry) = self.entries.remove(index - i).expect("bug");
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
let entry_batch_span =
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
batch_requests.push(Request {
let request = Request {
id,
inputs: entry.request.inputs.clone(),
truncate: entry.request.truncate,
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
});
};
// Set batch_time
entry.batch_time = Some(Instant::now());
entry.batch_time = some_now;
// Insert in batch_entries IntMap
batch_entries.insert(id, entry);
if batch_requests.len() == max_batch_size {
// We have enough requests in the batch
break;
}
}
request
}).collect::<Vec<Request>>();
metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
// Maybe all entries were dropped because their channel were closed
if batch_requests.is_empty() {
return None;
}
// Final batch size once we dropped entries
let size = batch_requests.len() as u32;
next_batch_span.record("batch_size", size);
@ -209,224 +402,30 @@ impl State {
};
// Increment batch id
self.next_batch_id += 1;
self.buffer_contents_insufficient = false;
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
Some((batch_entries, batch, next_batch_span))
(existing_entries_opt, Some((batch_entries, batch, next_batch_span)))
}
/// Returns true if the entry at the front of the queue has been waiting for longer
/// than MAX_WAITING_DURATION
fn next_entry_waiting_too_long(&self) -> bool {
matches!(
self.entries.front(), Some((_, e)) if e.queue_time.elapsed() > MAX_WAITING_DURATION
)
}
}
type ExistingBatch = IntMap<u64, Entry>;
type NextBatch = (IntMap<u64, Entry>, Batch, Span);
#[derive(Debug)]
enum QueueCommand {
Append(Entry, Span),
NextBatch {
min_size: Option<usize>,
max_size: usize,
response_sender: oneshot::Sender<Option<NextBatch>>,
entries: Option<ExistingBatch>,
response_sender: oneshot::Sender<(Option<ExistingBatch>, Option<NextBatch>)>,
span: Span,
},
}
#[cfg(test)]
mod tests {
use super::*;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use tracing::info_span;
fn default_entry() -> (
Entry,
flume::Receiver<Result<InferStreamResponse, InferError>>,
) {
let (response_tx, receiver_tx) = flume::unbounded();
let entry = Entry {
request: ValidGenerateRequest {
inputs: "".to_string(),
truncate: 0,
parameters: NextTokenChooserParameters {
temperature: 0.0,
top_k: 0,
top_p: 0.0,
typical_p: 0.0,
do_sample: false,
seed: 0,
repetition_penalty: 0.0,
watermark: false,
},
stopping_parameters: StoppingCriteriaParameters {
ignore_eos_token: false,
max_new_tokens: 0,
stop_sequences: vec![],
},
},
response_tx,
span: info_span!("entry"),
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
};
(entry, receiver_tx)
}
#[test]
fn test_append() {
let mut state = State::new();
let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0);
assert_eq!(state.entries.len(), 0);
state.append(entry);
assert_eq!(state.next_id, 1);
assert_eq!(state.entries.len(), 1);
let (id, _) = state.entries.remove(0).unwrap();
assert_eq!(id, 0);
}
#[test]
fn test_next_batch_empty() {
let mut state = State::new();
assert!(state.next_batch(None, 1).is_none());
assert!(state.next_batch(Some(1), 1).is_none());
}
#[test]
fn test_next_batch_min_size() {
let mut state = State::new();
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, 2).unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
assert!(entries.get(&0).unwrap().batch_time.is_some());
assert!(entries.get(&1).unwrap().batch_time.is_some());
assert_eq!(batch.id, 0);
assert_eq!(batch.size, 2);
assert_eq!(state.next_id, 2);
assert_eq!(state.entries.len(), 0);
assert_eq!(state.next_batch_id, 1);
let (entry3, _guard3) = default_entry();
state.append(entry3);
assert!(state.next_batch(Some(2), 2).is_none());
assert_eq!(state.next_id, 3);
assert_eq!(state.entries.len(), 1);
let (id, _) = state.entries.remove(0).unwrap();
assert_eq!(id, 2);
}
#[test]
fn test_next_batch_max_size() {
let mut state = State::new();
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, 1).unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0);
assert_eq!(batch.size, 1);
assert_eq!(state.next_id, 2);
assert_eq!(state.entries.len(), 1);
assert_eq!(state.next_batch_id, 1);
let (entry3, _guard3) = default_entry();
state.append(entry3);
let (entries, batch, _) = state.next_batch(None, 3).unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2));
assert_eq!(batch.id, 1);
assert_eq!(batch.size, 2);
assert_eq!(state.next_id, 3);
assert_eq!(state.entries.len(), 0);
assert_eq!(state.next_batch_id, 2);
}
#[tokio::test]
async fn test_queue_append() {
let queue = Queue::new();
let (entry, _guard) = default_entry();
queue.append(entry);
}
#[tokio::test]
async fn test_queue_next_batch_empty() {
let queue = Queue::new();
assert!(queue.next_batch(None, 1).await.is_none());
assert!(queue.next_batch(Some(1), 1).await.is_none());
}
#[tokio::test]
async fn test_queue_next_batch_min_size() {
let queue = Queue::new();
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
assert!(entries.get(&0).unwrap().batch_time.is_some());
assert!(entries.get(&1).unwrap().batch_time.is_some());
assert_eq!(batch.id, 0);
assert_eq!(batch.size, 2);
let (entry3, _guard3) = default_entry();
queue.append(entry3);
assert!(queue.next_batch(Some(2), 2).await.is_none());
}
#[tokio::test]
async fn test_queue_next_batch_max_size() {
let queue = Queue::new();
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0);
assert_eq!(batch.size, 1);
let (entry3, _guard3) = default_entry();
queue.append(entry3);
let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2));
assert_eq!(batch.id, 1);
assert_eq!(batch.size, 2);
}
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new();
let (entry, _) = default_entry();
queue.append(entry);
assert!(queue.next_batch(None, 1).await.is_none());
}
}

View File

@ -505,6 +505,8 @@ pub async fn run(
max_input_length: usize,
max_total_tokens: usize,
max_batch_size: usize,
max_batch_weight: Option<usize>,
max_prefill_weight: Option<usize>,
max_waiting_tokens: usize,
client: ShardedClient,
tokenizer: Option<Tokenizer>,
@ -552,6 +554,15 @@ pub async fn run(
)]
struct ApiDoc;
// If max batch weight is not set, infer from max batch size and max seq length
let max_batch_weight = max_batch_weight
.unwrap_or(max_batch_size * max_total_tokens);
let max_prefill_weight = max_prefill_weight.unwrap_or_default();
if max_total_tokens > max_batch_weight {
panic!("max_total_tokens cannot be greater than max_batch_weight");
}
// Create state
let validation = Validation::new(
validation_workers,
@ -565,6 +576,8 @@ pub async fn run(
client,
validation,
max_batch_size,
max_batch_weight,
max_prefill_weight,
max_waiting_tokens,
max_concurrent_requests,
);

View File

@ -1,3 +1,4 @@
use std::cmp::min;
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
/// Payload validation logic
use crate::{GenerateParameters, GenerateRequest};
@ -69,7 +70,7 @@ impl Validation {
inputs: String,
truncate: Option<usize>,
max_new_tokens: u32,
) -> Result<String, ValidationError> {
) -> Result<(String, usize), ValidationError> {
// If we have a fast tokenizer
if let Some(sender) = &self.sender {
// Create response channel
@ -105,7 +106,7 @@ impl Validation {
}
metrics::histogram!("tgi_request_input_length", input_length as f64);
Ok(inputs)
Ok((inputs, input_length))
}
// Return inputs without validation
else {
@ -123,7 +124,7 @@ impl Validation {
));
}
Ok(inputs)
Ok((inputs, min(truncate.unwrap_or(usize::MAX), self.max_input_length)))
}
}
@ -238,7 +239,7 @@ impl Validation {
.unwrap_or(Ok(None))?;
// Validate inputs
let inputs = self
let (inputs, input_length) = self
.validate_input(request.inputs, truncate, max_new_tokens)
.await?;
@ -262,7 +263,7 @@ impl Validation {
Ok(ValidGenerateRequest {
inputs,
truncate: truncate.unwrap_or(self.max_input_length) as u32,
truncate: input_length as u32,
parameters,
stopping_parameters,
})