mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
feat(router): dynamic batch sizing
This commit is contained in:
parent
709d8936f6
commit
ba1aae3e78
@ -41,6 +41,10 @@ struct Args {
|
|||||||
max_total_tokens: usize,
|
max_total_tokens: usize,
|
||||||
#[clap(default_value = "32", long, env)]
|
#[clap(default_value = "32", long, env)]
|
||||||
max_batch_size: usize,
|
max_batch_size: usize,
|
||||||
|
#[clap(default_value = None, long, env)]
|
||||||
|
max_batch_weight: Option<usize>,
|
||||||
|
#[clap(default_value = None, long, env)]
|
||||||
|
max_prefill_weight: Option<usize>,
|
||||||
#[clap(default_value = "20", long, env)]
|
#[clap(default_value = "20", long, env)]
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
#[clap(default_value = "3000", long, short, env)]
|
#[clap(default_value = "3000", long, short, env)]
|
||||||
@ -93,6 +97,8 @@ fn main() -> ExitCode {
|
|||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
|
max_batch_weight,
|
||||||
|
max_prefill_weight,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
port,
|
port,
|
||||||
shard_uds_path,
|
shard_uds_path,
|
||||||
@ -392,6 +398,16 @@ fn main() -> ExitCode {
|
|||||||
model_id,
|
model_id,
|
||||||
];
|
];
|
||||||
|
|
||||||
|
if let Some(max_batch_weight) = max_batch_weight {
|
||||||
|
argv.push("--max-batch-weight".to_string());
|
||||||
|
argv.push(max_batch_weight.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(max_prefill_weight) = max_prefill_weight {
|
||||||
|
argv.push("--max-batch-weight".to_string());
|
||||||
|
argv.push(max_prefill_weight.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
// Model optional revision
|
// Model optional revision
|
||||||
if let Some(ref revision) = revision {
|
if let Some(ref revision) = revision {
|
||||||
argv.push("--revision".to_string());
|
argv.push("--revision".to_string());
|
||||||
|
@ -15,6 +15,7 @@ use thiserror::Error;
|
|||||||
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tracing::{info_span, instrument, Instrument, Span};
|
use tracing::{info_span, instrument, Instrument, Span};
|
||||||
|
use crate::queue::BatchingConfig;
|
||||||
|
|
||||||
/// Inference struct
|
/// Inference struct
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@ -40,11 +41,17 @@ impl Infer {
|
|||||||
client: ShardedClient,
|
client: ShardedClient,
|
||||||
validation: Validation,
|
validation: Validation,
|
||||||
max_batch_size: usize,
|
max_batch_size: usize,
|
||||||
|
max_batch_weight: usize,
|
||||||
|
max_prefill_weight: usize,
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
max_concurrent_requests: usize,
|
max_concurrent_requests: usize,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Infer shared state
|
// Infer shared state
|
||||||
let queue = Queue::new();
|
let queue = Queue::new(BatchingConfig {
|
||||||
|
size_limit: max_batch_size,
|
||||||
|
weight_limit: max_batch_weight,
|
||||||
|
prefill_weight_limit: max_prefill_weight,
|
||||||
|
});
|
||||||
let shared = Arc::new(Shared {
|
let shared = Arc::new(Shared {
|
||||||
batching_task: Notify::new(),
|
batching_task: Notify::new(),
|
||||||
});
|
});
|
||||||
@ -52,7 +59,6 @@ impl Infer {
|
|||||||
// Spawn batching background task that contains all the inference logic
|
// Spawn batching background task that contains all the inference logic
|
||||||
tokio::spawn(batching_task(
|
tokio::spawn(batching_task(
|
||||||
client,
|
client,
|
||||||
max_batch_size,
|
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
queue.clone(),
|
queue.clone(),
|
||||||
shared.clone(),
|
shared.clone(),
|
||||||
@ -105,6 +111,7 @@ impl Infer {
|
|||||||
// Append the request to the queue
|
// Append the request to the queue
|
||||||
self.queue.append(Entry {
|
self.queue.append(Entry {
|
||||||
request: valid_request,
|
request: valid_request,
|
||||||
|
generated_tokens: 0,
|
||||||
response_tx,
|
response_tx,
|
||||||
span: Span::current(),
|
span: Span::current(),
|
||||||
temp_span: None,
|
temp_span: None,
|
||||||
@ -232,18 +239,11 @@ impl Infer {
|
|||||||
/// Batches requests and sends them to the inference server
|
/// Batches requests and sends them to the inference server
|
||||||
async fn batching_task(
|
async fn batching_task(
|
||||||
mut client: ShardedClient,
|
mut client: ShardedClient,
|
||||||
max_batch_size: usize,
|
// max_batch_size: usize,
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
queue: Queue,
|
queue: Queue,
|
||||||
shared: Arc<Shared>,
|
shared: Arc<Shared>,
|
||||||
) {
|
) {
|
||||||
// Minimum batch size after which we try to add more requests
|
|
||||||
let limit_min_batch_size = if max_batch_size > 1 {
|
|
||||||
(max_batch_size / 2) as u32
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
};
|
|
||||||
|
|
||||||
// Infinite loop
|
// Infinite loop
|
||||||
loop {
|
loop {
|
||||||
// Wait for a notification from the Infer struct
|
// Wait for a notification from the Infer struct
|
||||||
@ -252,8 +252,8 @@ async fn batching_task(
|
|||||||
// Get the next batch from the queue
|
// Get the next batch from the queue
|
||||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||||
// waiting in the queue
|
// waiting in the queue
|
||||||
while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
|
while let (_, Some((mut entries, batch, span))) = queue.next_batch(None).await {
|
||||||
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
let (mut cached_batch, mut some_completed) = prefill(&mut client, batch, &mut entries)
|
||||||
.instrument(span)
|
.instrument(span)
|
||||||
.await;
|
.await;
|
||||||
let mut waiting_tokens = 1;
|
let mut waiting_tokens = 1;
|
||||||
@ -266,21 +266,16 @@ async fn batching_task(
|
|||||||
let mut batches = vec![batch];
|
let mut batches = vec![batch];
|
||||||
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
|
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
|
||||||
|
|
||||||
// If the current batch is too small, we try to add more requests to it
|
// Try to extend batch if its size reduced or enough tokens have elapsed since last one
|
||||||
if batch_size <= limit_min_batch_size {
|
if some_completed || waiting_tokens >= max_waiting_tokens {
|
||||||
let min_size = match waiting_tokens {
|
|
||||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
|
||||||
// to add a new batch even though its size might be small
|
|
||||||
_ if waiting_tokens >= max_waiting_tokens => None,
|
|
||||||
// Minimum size criteria
|
|
||||||
_ => Some(limit_min_batch_size as usize),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Try to get a new batch
|
// Try to get a new batch - ownership of entries passed in and out
|
||||||
if let Some((mut new_entries, new_batch, span)) = queue
|
let (
|
||||||
.next_batch(min_size, max_batch_size - batch_size as usize)
|
existing_entries, new_entries
|
||||||
.await
|
) = queue.next_batch(Some(entries)).await;
|
||||||
{
|
entries = existing_entries.unwrap();
|
||||||
|
|
||||||
|
if let Some((mut new_entries, new_batch, span)) = new_entries {
|
||||||
entries.iter_mut().for_each(|(_, entry)| {
|
entries.iter_mut().for_each(|(_, entry)| {
|
||||||
// Create a new span to add the info that this entry is waiting
|
// Create a new span to add the info that this entry is waiting
|
||||||
// because a new batch is being computed
|
// because a new batch is being computed
|
||||||
@ -293,7 +288,7 @@ async fn batching_task(
|
|||||||
});
|
});
|
||||||
|
|
||||||
// Generate one token for this new batch to have the attention past in cache
|
// Generate one token for this new batch to have the attention past in cache
|
||||||
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
|
let (new_cached_batch, _) = prefill(&mut client, new_batch, &mut new_entries)
|
||||||
.instrument(span)
|
.instrument(span)
|
||||||
.await;
|
.await;
|
||||||
// Reset waiting counter
|
// Reset waiting counter
|
||||||
@ -319,7 +314,7 @@ async fn batching_task(
|
|||||||
entry.temp_span = Some(entry_batch_span);
|
entry.temp_span = Some(entry_batch_span);
|
||||||
});
|
});
|
||||||
|
|
||||||
cached_batch = decode(&mut client, batches, &mut entries)
|
(cached_batch, some_completed) = decode(&mut client, batches, &mut entries)
|
||||||
.instrument(next_batch_span)
|
.instrument(next_batch_span)
|
||||||
.await;
|
.await;
|
||||||
waiting_tokens += 1;
|
waiting_tokens += 1;
|
||||||
@ -334,14 +329,14 @@ async fn prefill(
|
|||||||
client: &mut ShardedClient,
|
client: &mut ShardedClient,
|
||||||
batch: Batch,
|
batch: Batch,
|
||||||
entries: &mut IntMap<u64, Entry>,
|
entries: &mut IntMap<u64, Entry>,
|
||||||
) -> Option<Batch> {
|
) -> (Option<Batch>, bool) {
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
let batch_id = batch.id;
|
let batch_id = batch.id;
|
||||||
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
|
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
|
||||||
|
|
||||||
match client.prefill(batch).await {
|
match client.prefill(batch).await {
|
||||||
Ok((generations, next_batch)) => {
|
Ok((generations, next_batch)) => {
|
||||||
filter_send_generations(generations, entries);
|
let some_completed = filter_send_generations(generations, entries);
|
||||||
|
|
||||||
// Filter next batch and remove requests that were stopped
|
// Filter next batch and remove requests that were stopped
|
||||||
let next_batch = match next_batch {
|
let next_batch = match next_batch {
|
||||||
@ -360,14 +355,14 @@ async fn prefill(
|
|||||||
|
|
||||||
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
|
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
|
||||||
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
|
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
|
||||||
next_batch
|
(next_batch, some_completed)
|
||||||
}
|
}
|
||||||
// If we have an error, we discard the whole batch
|
// If we have an error, we discard the whole batch
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
let _ = client.clear_cache(Some(batch_id)).await;
|
let _ = client.clear_cache(Some(batch_id)).await;
|
||||||
send_errors(err, entries);
|
send_errors(err, entries);
|
||||||
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
|
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
|
||||||
None
|
(None, true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -377,14 +372,14 @@ async fn decode(
|
|||||||
client: &mut ShardedClient,
|
client: &mut ShardedClient,
|
||||||
batches: Vec<Batch>,
|
batches: Vec<Batch>,
|
||||||
entries: &mut IntMap<u64, Entry>,
|
entries: &mut IntMap<u64, Entry>,
|
||||||
) -> Option<Batch> {
|
) -> (Option<Batch>, bool) {
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
||||||
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
|
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
|
||||||
|
|
||||||
match client.decode(batches).await {
|
match client.decode(batches).await {
|
||||||
Ok((generations, next_batch)) => {
|
Ok((generations, next_batch)) => {
|
||||||
filter_send_generations(generations, entries);
|
let some_completed = filter_send_generations(generations, entries);
|
||||||
|
|
||||||
// Filter next batch and remove requests that were stopped
|
// Filter next batch and remove requests that were stopped
|
||||||
let next_batch = match next_batch {
|
let next_batch = match next_batch {
|
||||||
@ -403,7 +398,7 @@ async fn decode(
|
|||||||
|
|
||||||
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
|
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
|
||||||
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
|
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
|
||||||
next_batch
|
(next_batch, some_completed)
|
||||||
}
|
}
|
||||||
// If we have an error, we discard the whole batch
|
// If we have an error, we discard the whole batch
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
@ -412,7 +407,7 @@ async fn decode(
|
|||||||
}
|
}
|
||||||
send_errors(err, entries);
|
send_errors(err, entries);
|
||||||
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
|
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
|
||||||
None
|
(None, true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -431,14 +426,16 @@ fn filter_batch(mut batch: Batch, entries: &IntMap<u64, Entry>) -> Option<Batch>
|
|||||||
|
|
||||||
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||||
/// and filter entries
|
/// and filter entries
|
||||||
|
/// Return true if any requests completed
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) -> bool {
|
||||||
|
let mut some_stopped = false;
|
||||||
generations.into_iter().for_each(|generation| {
|
generations.into_iter().for_each(|generation| {
|
||||||
let id = generation.request_id;
|
let id = generation.request_id;
|
||||||
// Get entry
|
// Get entry
|
||||||
// We can `expect` here as the request id should always be in the entries
|
// We can `expect` here as the request id should always be in the entries
|
||||||
let entry = entries
|
let entry = entries
|
||||||
.get(&id)
|
.get_mut(&id)
|
||||||
.expect("ID not found in entries. This is a bug.");
|
.expect("ID not found in entries. This is a bug.");
|
||||||
|
|
||||||
// Create and enter a span to link this function back to the entry
|
// Create and enter a span to link this function back to the entry
|
||||||
@ -451,9 +448,14 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
|
|||||||
err
|
err
|
||||||
}).unwrap_or(true);
|
}).unwrap_or(true);
|
||||||
if stopped {
|
if stopped {
|
||||||
|
some_stopped = true;
|
||||||
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||||
|
} else {
|
||||||
|
// Increment generated token count
|
||||||
|
entry.generated_tokens += 1;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
return some_stopped;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send responses through the `entry` response channel
|
/// Send responses through the `entry` response channel
|
||||||
|
@ -33,6 +33,10 @@ struct Args {
|
|||||||
max_total_tokens: usize,
|
max_total_tokens: usize,
|
||||||
#[clap(default_value = "32", long, env)]
|
#[clap(default_value = "32", long, env)]
|
||||||
max_batch_size: usize,
|
max_batch_size: usize,
|
||||||
|
#[clap(default_value = None, long, env)]
|
||||||
|
max_batch_weight: Option<usize>,
|
||||||
|
#[clap(default_value = None, long, env)]
|
||||||
|
max_prefill_weight: Option<usize>,
|
||||||
#[clap(default_value = "20", long, env)]
|
#[clap(default_value = "20", long, env)]
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
#[clap(default_value = "3000", long, short, env)]
|
#[clap(default_value = "3000", long, short, env)]
|
||||||
@ -64,6 +68,8 @@ fn main() -> Result<(), std::io::Error> {
|
|||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
|
max_batch_weight,
|
||||||
|
max_prefill_weight,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
port,
|
port,
|
||||||
master_shard_uds_path,
|
master_shard_uds_path,
|
||||||
@ -169,6 +175,8 @@ fn main() -> Result<(), std::io::Error> {
|
|||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
|
max_batch_weight,
|
||||||
|
max_prefill_weight,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
sharded_client,
|
sharded_client,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
@ -2,8 +2,9 @@ use crate::infer::InferError;
|
|||||||
use crate::infer::InferStreamResponse;
|
use crate::infer::InferStreamResponse;
|
||||||
use crate::validation::ValidGenerateRequest;
|
use crate::validation::ValidGenerateRequest;
|
||||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||||
use std::cmp::min;
|
use std::collections::{BTreeSet, VecDeque};
|
||||||
use std::collections::VecDeque;
|
use std::ops::Add;
|
||||||
|
use std::time::Duration;
|
||||||
use text_generation_client::{Batch, Request};
|
use text_generation_client::{Batch, Request};
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
@ -14,6 +15,8 @@ use tracing::{info_span, instrument, Span};
|
|||||||
pub(crate) struct Entry {
|
pub(crate) struct Entry {
|
||||||
/// Request
|
/// Request
|
||||||
pub request: ValidGenerateRequest,
|
pub request: ValidGenerateRequest,
|
||||||
|
/// Count of tokens generated so far
|
||||||
|
pub generated_tokens: usize,
|
||||||
/// Response sender to communicate between the Infer struct and the batching_task
|
/// Response sender to communicate between the Infer struct and the batching_task
|
||||||
pub response_tx: flume::Sender<Result<InferStreamResponse, InferError>>,
|
pub response_tx: flume::Sender<Result<InferStreamResponse, InferError>>,
|
||||||
/// Span that will live as long as entry
|
/// Span that will live as long as entry
|
||||||
@ -34,12 +37,12 @@ pub(crate) struct Queue {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Queue {
|
impl Queue {
|
||||||
pub(crate) fn new() -> Self {
|
pub(crate) fn new(config: BatchingConfig) -> Self {
|
||||||
// Create channel
|
// Create channel
|
||||||
let (queue_sender, queue_receiver) = flume::unbounded();
|
let (queue_sender, queue_receiver) = flume::unbounded();
|
||||||
|
|
||||||
// Launch background queue task
|
// Launch background queue task
|
||||||
tokio::spawn(queue_task(queue_receiver));
|
tokio::spawn(queue_task(queue_receiver, config));
|
||||||
|
|
||||||
Self { queue_sender }
|
Self { queue_sender }
|
||||||
}
|
}
|
||||||
@ -54,21 +57,19 @@ impl Queue {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the next batch
|
// Get the next batch - existing batch is returned unchanged
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub(crate) async fn next_batch(
|
pub(crate) async fn next_batch(
|
||||||
&self,
|
&self,
|
||||||
min_size: Option<usize>,
|
entries: Option<ExistingBatch>,
|
||||||
max_size: usize,
|
) -> (Option<ExistingBatch>, Option<NextBatch>) {
|
||||||
) -> Option<NextBatch> {
|
|
||||||
// Create response channel
|
// Create response channel
|
||||||
let (response_sender, response_receiver) = oneshot::channel();
|
let (response_sender, response_receiver) = oneshot::channel();
|
||||||
// Send next batch command to the background task managing the state
|
// Send next batch command to the background task managing the state
|
||||||
// Unwrap is safe here
|
// Unwrap is safe here
|
||||||
self.queue_sender
|
self.queue_sender
|
||||||
.send(QueueCommand::NextBatch {
|
.send(QueueCommand::NextBatch {
|
||||||
min_size,
|
entries,
|
||||||
max_size,
|
|
||||||
response_sender,
|
response_sender,
|
||||||
span: Span::current(),
|
span: Span::current(),
|
||||||
})
|
})
|
||||||
@ -80,28 +81,37 @@ impl Queue {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Background task responsible of the queue state
|
// Background task responsible of the queue state
|
||||||
async fn queue_task(receiver: flume::Receiver<QueueCommand>) {
|
async fn queue_task(receiver: flume::Receiver<QueueCommand>, config: BatchingConfig) {
|
||||||
let mut state = State::new();
|
let mut state = State::new(config);
|
||||||
|
|
||||||
while let Ok(cmd) = receiver.recv_async().await {
|
while let Ok(cmd) = receiver.recv_async().await {
|
||||||
match cmd {
|
match cmd {
|
||||||
QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)),
|
QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)),
|
||||||
QueueCommand::NextBatch {
|
QueueCommand::NextBatch {
|
||||||
min_size,
|
entries,
|
||||||
max_size,
|
|
||||||
response_sender,
|
response_sender,
|
||||||
span,
|
span,
|
||||||
} => span.in_scope(|| {
|
} => span.in_scope(|| {
|
||||||
let next_batch = state.next_batch(min_size, max_size);
|
let response = state.next_batch(entries);
|
||||||
response_sender.send(next_batch).unwrap_or(());
|
response_sender.send(response).unwrap_or(());
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct BatchingConfig {
|
||||||
|
pub(crate) size_limit: usize,
|
||||||
|
pub(crate) weight_limit: usize,
|
||||||
|
pub(crate) prefill_weight_limit: usize,
|
||||||
|
}
|
||||||
|
|
||||||
/// Queue State
|
/// Queue State
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct State {
|
struct State {
|
||||||
|
/// Batching configuration
|
||||||
|
config: BatchingConfig,
|
||||||
|
|
||||||
/// Queue entries organized in a Vec
|
/// Queue entries organized in a Vec
|
||||||
entries: VecDeque<(u64, Entry)>,
|
entries: VecDeque<(u64, Entry)>,
|
||||||
|
|
||||||
@ -110,14 +120,46 @@ struct State {
|
|||||||
|
|
||||||
/// Id of the next batch
|
/// Id of the next batch
|
||||||
next_batch_id: u64,
|
next_batch_id: u64,
|
||||||
|
|
||||||
|
// Remembered size of the last batch, used to determine
|
||||||
|
// when entries have completed between calls to the
|
||||||
|
// next_batch function
|
||||||
|
last_seen_batch_size: usize,
|
||||||
|
|
||||||
|
// Index in the queue up to which entries have been
|
||||||
|
// checked to see if they can fit into the current batch.
|
||||||
|
// Reset to zero when any existing entries complete
|
||||||
|
checked_request_count: usize,
|
||||||
|
|
||||||
|
/// true if it's known that the current size of the
|
||||||
|
/// requests in the queue is too small to prefill an add-on batch
|
||||||
|
buffer_contents_insufficient: bool,
|
||||||
|
|
||||||
|
/// Just a constant empty map to reuse
|
||||||
|
empty_map: ExistingBatch,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Could also make these configurable
|
||||||
|
|
||||||
|
/// Longest that requests can be waiting before we ignore the minimum
|
||||||
|
/// size requirement when adding to a new batch
|
||||||
|
const MAX_WAITING_DURATION: Duration = Duration::from_secs(1);
|
||||||
|
|
||||||
|
/// Maximum difference in arrival time that smaller requests can jump
|
||||||
|
/// ahead of larger ones in the queue
|
||||||
|
const CUTOFF_DURATION: Duration = Duration::from_secs(1);
|
||||||
|
|
||||||
impl State {
|
impl State {
|
||||||
fn new() -> Self {
|
fn new(config: BatchingConfig) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
config,
|
||||||
entries: VecDeque::with_capacity(128),
|
entries: VecDeque::with_capacity(128),
|
||||||
next_id: 0,
|
next_id: 0,
|
||||||
next_batch_id: 0,
|
next_batch_id: 0,
|
||||||
|
last_seen_batch_size: 0,
|
||||||
|
checked_request_count: 0,
|
||||||
|
buffer_contents_insufficient: false,
|
||||||
|
empty_map: IntMap::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -134,70 +176,221 @@ impl State {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get the next batch
|
// Get the next batch
|
||||||
fn next_batch(&mut self, min_size: Option<usize>, max_size: usize) -> Option<NextBatch> {
|
fn next_batch(
|
||||||
if self.entries.is_empty() {
|
&mut self, existing_entries_opt: Option<ExistingBatch>,
|
||||||
return None;
|
) -> (Option<ExistingBatch>, Option<NextBatch>) {
|
||||||
|
|
||||||
|
// Use ref to empty map in None case to simplify subsequent logic
|
||||||
|
let existing_entries = existing_entries_opt.as_ref().unwrap_or(&self.empty_map);
|
||||||
|
|
||||||
|
let config = &self.config;
|
||||||
|
let mut total_count = existing_entries.len();
|
||||||
|
if total_count >= config.size_limit {
|
||||||
|
// We are already at max batch size
|
||||||
|
return (existing_entries_opt, None)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if we have enough entries
|
if total_count != self.last_seen_batch_size {
|
||||||
if let Some(min_size) = min_size {
|
// Reset the count of checked requests if any completed since last check
|
||||||
if self.entries.len() < min_size {
|
self.checked_request_count = 0;
|
||||||
return None;
|
self.last_seen_batch_size = total_count
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter entries where the response receiver was dropped (== entries where the request
|
||||||
|
// was dropped by the client)
|
||||||
|
let queue_len_before = self.entries.len();
|
||||||
|
self.entries.retain_mut(|(_, entry)| !entry.response_tx.is_disconnected());
|
||||||
|
if queue_len_before != self.entries.len() {
|
||||||
|
// Reset the count of checked requests if any in the queue were cancelled since last check
|
||||||
|
self.checked_request_count = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// This will generally be zero, but if no requests have been completed
|
||||||
|
// since last time, we don't need to reconsider those already checked
|
||||||
|
let mut checked_up_to_index = self.checked_request_count;
|
||||||
|
|
||||||
|
if !existing_entries.is_empty() {
|
||||||
|
// If we don't have any new requests in the buffer to check
|
||||||
|
if self.entries.len() <= checked_up_to_index ||
|
||||||
|
// Or the current buffer isn't large enough to satisfy the min prefill requirement
|
||||||
|
self.buffer_contents_insufficient && !self.next_entry_waiting_too_long() {
|
||||||
|
return (existing_entries_opt, None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let max_batch_size = min(self.entries.len(), max_size);
|
// Indices into buffer of entries chosen to add to next batch or remove due to expiry
|
||||||
|
let mut chosen_indices = vec![];
|
||||||
|
// Indices to drop due to client cancellation
|
||||||
|
let mut indices_to_drop = vec![];
|
||||||
|
let mut btree = None;
|
||||||
|
let mut time_cutoff = None;
|
||||||
|
let mut hit_prefill_weight_limit = false;
|
||||||
|
|
||||||
|
let mut total_token_count = existing_entries.iter().map(
|
||||||
|
|(_, e)| e.request.stopping_parameters.max_new_tokens + e.request.truncate
|
||||||
|
).sum::<u32>() as usize;
|
||||||
|
|
||||||
|
let mut prefill_size = 0;
|
||||||
|
// We first do a read-only pass over the queue to allow skipping over large entries
|
||||||
|
// that don't fit in the current batch to reach smaller entries that do
|
||||||
|
let mut queue_index = checked_up_to_index;
|
||||||
|
'queue_loop: for (entry_id, entry) in self.entries.range(queue_index..) {
|
||||||
|
if matches!(time_cutoff, Some(t) if entry.queue_time > t) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
queue_index += 1;
|
||||||
|
if entry.response_tx.is_disconnected() {
|
||||||
|
// Eject cancelled entry from queue
|
||||||
|
indices_to_drop.push(queue_index);
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// This is the index into the queue after cancelled entries
|
||||||
|
// have been pruned
|
||||||
|
checked_up_to_index += 1;
|
||||||
|
|
||||||
|
let input_len = entry.request.truncate as usize;
|
||||||
|
let output_len = entry.request.stopping_parameters.max_new_tokens as usize;
|
||||||
|
let next_total_token_count = total_token_count + input_len + output_len;
|
||||||
|
|
||||||
|
// Avoid more granular analysis if possible
|
||||||
|
if next_total_token_count > config.weight_limit {
|
||||||
|
// We aren't sure whether this next request will fit, so populate
|
||||||
|
// a btree with the current batch of requests, the set of
|
||||||
|
// requests already evaluated, and this one, and perform more
|
||||||
|
// granular analysis to verify that the batch shape won't exceed
|
||||||
|
// the limit at any point
|
||||||
|
|
||||||
|
// Allocate btree the first time it's required
|
||||||
|
let tree = btree.get_or_insert_with(|| {
|
||||||
|
let mut t = Box::new(BTreeSet::new());
|
||||||
|
// Populate with records corresponding to all existing and pending entries
|
||||||
|
let pending = chosen_indices.iter()
|
||||||
|
.map(|i| self.entries.get(*i).unwrap())
|
||||||
|
.map(|(eid, e)| (eid, e));
|
||||||
|
for (eid, e) in existing_entries.iter().chain(pending) {
|
||||||
|
let generated_count = e.generated_tokens;
|
||||||
|
t.insert((
|
||||||
|
e.request.stopping_parameters.max_new_tokens as usize - generated_count,
|
||||||
|
e.request.truncate as usize + e.generated_tokens,
|
||||||
|
eid,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
t
|
||||||
|
});
|
||||||
|
// Add the current entry
|
||||||
|
tree.insert((output_len, input_len, entry_id));
|
||||||
|
|
||||||
|
// Perform analysis
|
||||||
|
let mut in_sum = 0;
|
||||||
|
// Work backwards from longest projected entry
|
||||||
|
for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
|
||||||
|
let this_ol = *ol;
|
||||||
|
in_sum += *il;
|
||||||
|
if this_ol <= output_len {
|
||||||
|
// Check if we breach max space for this segment
|
||||||
|
let token_count = in_sum + (bs + 1) * this_ol;
|
||||||
|
if token_count > config.weight_limit {
|
||||||
|
// Remove our tuple from the set
|
||||||
|
tree.remove(&(output_len, input_len, entry_id));
|
||||||
|
time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
|
||||||
|
continue 'queue_loop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if let Some(tree) = btree.as_mut() {
|
||||||
|
// If we initialized the btree for a prior request, keep it updated
|
||||||
|
tree.insert((output_len, input_len, entry_id));
|
||||||
|
}
|
||||||
|
// Here, we can add this request to the batch without breaching memory limit
|
||||||
|
|
||||||
|
if config.prefill_weight_limit > 0 {
|
||||||
|
// Also check whether adding this request will make the batch of new requests
|
||||||
|
// too expensive latency-wise to perform in a single forward-pass.
|
||||||
|
if prefill_size + input_len > config.prefill_weight_limit {
|
||||||
|
if let Some(tree) = btree.as_mut() {
|
||||||
|
// Remove our tuple from the set
|
||||||
|
tree.remove(&(output_len, input_len, entry_id));
|
||||||
|
hit_prefill_weight_limit = true;
|
||||||
|
}
|
||||||
|
time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
total_token_count = next_total_token_count;
|
||||||
|
prefill_size += input_len;
|
||||||
|
|
||||||
|
chosen_indices.push(queue_index - 1);
|
||||||
|
total_count += 1;
|
||||||
|
if total_count >= config.size_limit {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drop any cancelled requests
|
||||||
|
if !indices_to_drop.is_empty() {
|
||||||
|
indices_to_drop.iter().for_each(|i| {
|
||||||
|
self.entries.remove(*i);
|
||||||
|
});
|
||||||
|
metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
|
||||||
|
}
|
||||||
|
|
||||||
|
let next_batch_size = chosen_indices.len();
|
||||||
|
if next_batch_size == 0 {
|
||||||
|
// This gets reset to zero when any requests in the existing batch are removed
|
||||||
|
self.checked_request_count = checked_up_to_index;
|
||||||
|
return (existing_entries_opt, None)
|
||||||
|
}
|
||||||
|
self.checked_request_count = 0;
|
||||||
|
|
||||||
|
if !hit_prefill_weight_limit && !existing_entries.is_empty() {
|
||||||
|
// If this is to be added to an existing batch, ensure it meets urgency or size
|
||||||
|
// requirements to avoid too frequent prefills
|
||||||
|
if !self.next_entry_waiting_too_long() {
|
||||||
|
if total_token_count < config.weight_limit / 2 {
|
||||||
|
// Don't add this new batch yet because it's not large enough
|
||||||
|
self.checked_request_count = checked_up_to_index;
|
||||||
|
self.buffer_contents_insufficient = true;
|
||||||
|
return (existing_entries_opt, None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create span for this batch to add context to inference calls
|
// Create span for this batch to add context to inference calls
|
||||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
let next_batch_span = info_span!(parent: None, "batch", batch_size = next_batch_size);
|
||||||
next_batch_span.follows_from(&Span::current());
|
next_batch_span.follows_from(&Span::current());
|
||||||
|
|
||||||
let mut batch_requests = Vec::with_capacity(max_batch_size);
|
|
||||||
let mut batch_entries =
|
let mut batch_entries =
|
||||||
IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default());
|
IntMap::with_capacity_and_hasher(next_batch_size, BuildNoHashHasher::default());
|
||||||
|
|
||||||
// Iterate on buffer
|
|
||||||
while let Some((id, mut entry)) = self.entries.pop_front() {
|
|
||||||
// Filter entries where the response receiver was dropped (== entries where the request
|
|
||||||
// was dropped by the client)
|
|
||||||
if entry.response_tx.is_disconnected() {
|
|
||||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
let some_now = Some(Instant::now());
|
||||||
|
let batch_requests = chosen_indices.iter().enumerate().map(|(i, index)| {
|
||||||
|
let (id, mut entry) = self.entries.remove(index - i).expect("bug");
|
||||||
// Create a new span to link the batch back to this entry
|
// Create a new span to link the batch back to this entry
|
||||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
let entry_batch_span =
|
||||||
|
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
|
||||||
// Add relationships
|
// Add relationships
|
||||||
next_batch_span.follows_from(&entry_batch_span);
|
next_batch_span.follows_from(&entry_batch_span);
|
||||||
entry_batch_span.follows_from(&next_batch_span);
|
entry_batch_span.follows_from(&next_batch_span);
|
||||||
// Update entry
|
// Update entry
|
||||||
entry.temp_span = Some(entry_batch_span);
|
entry.temp_span = Some(entry_batch_span);
|
||||||
|
|
||||||
batch_requests.push(Request {
|
let request = Request {
|
||||||
id,
|
id,
|
||||||
inputs: entry.request.inputs.clone(),
|
inputs: entry.request.inputs.clone(),
|
||||||
truncate: entry.request.truncate,
|
truncate: entry.request.truncate,
|
||||||
parameters: Some(entry.request.parameters.clone()),
|
parameters: Some(entry.request.parameters.clone()),
|
||||||
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
|
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
|
||||||
});
|
};
|
||||||
// Set batch_time
|
// Set batch_time
|
||||||
entry.batch_time = Some(Instant::now());
|
entry.batch_time = some_now;
|
||||||
// Insert in batch_entries IntMap
|
// Insert in batch_entries IntMap
|
||||||
batch_entries.insert(id, entry);
|
batch_entries.insert(id, entry);
|
||||||
|
request
|
||||||
if batch_requests.len() == max_batch_size {
|
}).collect::<Vec<Request>>();
|
||||||
// We have enough requests in the batch
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
|
metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
|
||||||
|
|
||||||
// Maybe all entries were dropped because their channel were closed
|
|
||||||
if batch_requests.is_empty() {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Final batch size once we dropped entries
|
// Final batch size once we dropped entries
|
||||||
let size = batch_requests.len() as u32;
|
let size = batch_requests.len() as u32;
|
||||||
next_batch_span.record("batch_size", size);
|
next_batch_span.record("batch_size", size);
|
||||||
@ -209,224 +402,30 @@ impl State {
|
|||||||
};
|
};
|
||||||
// Increment batch id
|
// Increment batch id
|
||||||
self.next_batch_id += 1;
|
self.next_batch_id += 1;
|
||||||
|
self.buffer_contents_insufficient = false;
|
||||||
|
|
||||||
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
|
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
|
||||||
Some((batch_entries, batch, next_batch_span))
|
(existing_entries_opt, Some((batch_entries, batch, next_batch_span)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if the entry at the front of the queue has been waiting for longer
|
||||||
|
/// than MAX_WAITING_DURATION
|
||||||
|
fn next_entry_waiting_too_long(&self) -> bool {
|
||||||
|
matches!(
|
||||||
|
self.entries.front(), Some((_, e)) if e.queue_time.elapsed() > MAX_WAITING_DURATION
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ExistingBatch = IntMap<u64, Entry>;
|
||||||
type NextBatch = (IntMap<u64, Entry>, Batch, Span);
|
type NextBatch = (IntMap<u64, Entry>, Batch, Span);
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
enum QueueCommand {
|
enum QueueCommand {
|
||||||
Append(Entry, Span),
|
Append(Entry, Span),
|
||||||
NextBatch {
|
NextBatch {
|
||||||
min_size: Option<usize>,
|
entries: Option<ExistingBatch>,
|
||||||
max_size: usize,
|
response_sender: oneshot::Sender<(Option<ExistingBatch>, Option<NextBatch>)>,
|
||||||
response_sender: oneshot::Sender<Option<NextBatch>>,
|
|
||||||
span: Span,
|
span: Span,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
|
||||||
use tracing::info_span;
|
|
||||||
|
|
||||||
fn default_entry() -> (
|
|
||||||
Entry,
|
|
||||||
flume::Receiver<Result<InferStreamResponse, InferError>>,
|
|
||||||
) {
|
|
||||||
let (response_tx, receiver_tx) = flume::unbounded();
|
|
||||||
|
|
||||||
let entry = Entry {
|
|
||||||
request: ValidGenerateRequest {
|
|
||||||
inputs: "".to_string(),
|
|
||||||
truncate: 0,
|
|
||||||
parameters: NextTokenChooserParameters {
|
|
||||||
temperature: 0.0,
|
|
||||||
top_k: 0,
|
|
||||||
top_p: 0.0,
|
|
||||||
typical_p: 0.0,
|
|
||||||
do_sample: false,
|
|
||||||
seed: 0,
|
|
||||||
repetition_penalty: 0.0,
|
|
||||||
watermark: false,
|
|
||||||
},
|
|
||||||
stopping_parameters: StoppingCriteriaParameters {
|
|
||||||
ignore_eos_token: false,
|
|
||||||
max_new_tokens: 0,
|
|
||||||
stop_sequences: vec![],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
response_tx,
|
|
||||||
span: info_span!("entry"),
|
|
||||||
temp_span: None,
|
|
||||||
queue_time: Instant::now(),
|
|
||||||
batch_time: None,
|
|
||||||
};
|
|
||||||
(entry, receiver_tx)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_append() {
|
|
||||||
let mut state = State::new();
|
|
||||||
let (entry, _guard) = default_entry();
|
|
||||||
|
|
||||||
assert_eq!(state.next_id, 0);
|
|
||||||
assert_eq!(state.entries.len(), 0);
|
|
||||||
|
|
||||||
state.append(entry);
|
|
||||||
|
|
||||||
assert_eq!(state.next_id, 1);
|
|
||||||
assert_eq!(state.entries.len(), 1);
|
|
||||||
let (id, _) = state.entries.remove(0).unwrap();
|
|
||||||
assert_eq!(id, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_next_batch_empty() {
|
|
||||||
let mut state = State::new();
|
|
||||||
|
|
||||||
assert!(state.next_batch(None, 1).is_none());
|
|
||||||
assert!(state.next_batch(Some(1), 1).is_none());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_next_batch_min_size() {
|
|
||||||
let mut state = State::new();
|
|
||||||
let (entry1, _guard1) = default_entry();
|
|
||||||
let (entry2, _guard2) = default_entry();
|
|
||||||
state.append(entry1);
|
|
||||||
state.append(entry2);
|
|
||||||
|
|
||||||
let (entries, batch, _) = state.next_batch(None, 2).unwrap();
|
|
||||||
assert_eq!(entries.len(), 2);
|
|
||||||
assert!(entries.contains_key(&0));
|
|
||||||
assert!(entries.contains_key(&1));
|
|
||||||
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
|
||||||
assert!(entries.get(&1).unwrap().batch_time.is_some());
|
|
||||||
assert_eq!(batch.id, 0);
|
|
||||||
assert_eq!(batch.size, 2);
|
|
||||||
|
|
||||||
assert_eq!(state.next_id, 2);
|
|
||||||
assert_eq!(state.entries.len(), 0);
|
|
||||||
assert_eq!(state.next_batch_id, 1);
|
|
||||||
|
|
||||||
let (entry3, _guard3) = default_entry();
|
|
||||||
state.append(entry3);
|
|
||||||
|
|
||||||
assert!(state.next_batch(Some(2), 2).is_none());
|
|
||||||
|
|
||||||
assert_eq!(state.next_id, 3);
|
|
||||||
assert_eq!(state.entries.len(), 1);
|
|
||||||
let (id, _) = state.entries.remove(0).unwrap();
|
|
||||||
assert_eq!(id, 2);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_next_batch_max_size() {
|
|
||||||
let mut state = State::new();
|
|
||||||
let (entry1, _guard1) = default_entry();
|
|
||||||
let (entry2, _guard2) = default_entry();
|
|
||||||
state.append(entry1);
|
|
||||||
state.append(entry2);
|
|
||||||
|
|
||||||
let (entries, batch, _) = state.next_batch(None, 1).unwrap();
|
|
||||||
assert_eq!(entries.len(), 1);
|
|
||||||
assert!(entries.contains_key(&0));
|
|
||||||
assert_eq!(batch.id, 0);
|
|
||||||
assert_eq!(batch.size, 1);
|
|
||||||
|
|
||||||
assert_eq!(state.next_id, 2);
|
|
||||||
assert_eq!(state.entries.len(), 1);
|
|
||||||
assert_eq!(state.next_batch_id, 1);
|
|
||||||
|
|
||||||
let (entry3, _guard3) = default_entry();
|
|
||||||
state.append(entry3);
|
|
||||||
|
|
||||||
let (entries, batch, _) = state.next_batch(None, 3).unwrap();
|
|
||||||
assert_eq!(entries.len(), 2);
|
|
||||||
assert!(entries.contains_key(&1));
|
|
||||||
assert!(entries.contains_key(&2));
|
|
||||||
assert_eq!(batch.id, 1);
|
|
||||||
assert_eq!(batch.size, 2);
|
|
||||||
|
|
||||||
assert_eq!(state.next_id, 3);
|
|
||||||
assert_eq!(state.entries.len(), 0);
|
|
||||||
assert_eq!(state.next_batch_id, 2);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_queue_append() {
|
|
||||||
let queue = Queue::new();
|
|
||||||
let (entry, _guard) = default_entry();
|
|
||||||
queue.append(entry);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_queue_next_batch_empty() {
|
|
||||||
let queue = Queue::new();
|
|
||||||
|
|
||||||
assert!(queue.next_batch(None, 1).await.is_none());
|
|
||||||
assert!(queue.next_batch(Some(1), 1).await.is_none());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_queue_next_batch_min_size() {
|
|
||||||
let queue = Queue::new();
|
|
||||||
let (entry1, _guard1) = default_entry();
|
|
||||||
let (entry2, _guard2) = default_entry();
|
|
||||||
queue.append(entry1);
|
|
||||||
queue.append(entry2);
|
|
||||||
|
|
||||||
let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap();
|
|
||||||
assert_eq!(entries.len(), 2);
|
|
||||||
assert!(entries.contains_key(&0));
|
|
||||||
assert!(entries.contains_key(&1));
|
|
||||||
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
|
||||||
assert!(entries.get(&1).unwrap().batch_time.is_some());
|
|
||||||
assert_eq!(batch.id, 0);
|
|
||||||
assert_eq!(batch.size, 2);
|
|
||||||
|
|
||||||
let (entry3, _guard3) = default_entry();
|
|
||||||
queue.append(entry3);
|
|
||||||
|
|
||||||
assert!(queue.next_batch(Some(2), 2).await.is_none());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_queue_next_batch_max_size() {
|
|
||||||
let queue = Queue::new();
|
|
||||||
let (entry1, _guard1) = default_entry();
|
|
||||||
let (entry2, _guard2) = default_entry();
|
|
||||||
queue.append(entry1);
|
|
||||||
queue.append(entry2);
|
|
||||||
|
|
||||||
let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap();
|
|
||||||
assert_eq!(entries.len(), 1);
|
|
||||||
assert!(entries.contains_key(&0));
|
|
||||||
assert_eq!(batch.id, 0);
|
|
||||||
assert_eq!(batch.size, 1);
|
|
||||||
|
|
||||||
let (entry3, _guard3) = default_entry();
|
|
||||||
queue.append(entry3);
|
|
||||||
|
|
||||||
let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap();
|
|
||||||
assert_eq!(entries.len(), 2);
|
|
||||||
assert!(entries.contains_key(&1));
|
|
||||||
assert!(entries.contains_key(&2));
|
|
||||||
assert_eq!(batch.id, 1);
|
|
||||||
assert_eq!(batch.size, 2);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_queue_next_batch_dropped_receiver() {
|
|
||||||
let queue = Queue::new();
|
|
||||||
let (entry, _) = default_entry();
|
|
||||||
queue.append(entry);
|
|
||||||
|
|
||||||
assert!(queue.next_batch(None, 1).await.is_none());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -505,6 +505,8 @@ pub async fn run(
|
|||||||
max_input_length: usize,
|
max_input_length: usize,
|
||||||
max_total_tokens: usize,
|
max_total_tokens: usize,
|
||||||
max_batch_size: usize,
|
max_batch_size: usize,
|
||||||
|
max_batch_weight: Option<usize>,
|
||||||
|
max_prefill_weight: Option<usize>,
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
client: ShardedClient,
|
client: ShardedClient,
|
||||||
tokenizer: Option<Tokenizer>,
|
tokenizer: Option<Tokenizer>,
|
||||||
@ -552,6 +554,15 @@ pub async fn run(
|
|||||||
)]
|
)]
|
||||||
struct ApiDoc;
|
struct ApiDoc;
|
||||||
|
|
||||||
|
// If max batch weight is not set, infer from max batch size and max seq length
|
||||||
|
let max_batch_weight = max_batch_weight
|
||||||
|
.unwrap_or(max_batch_size * max_total_tokens);
|
||||||
|
let max_prefill_weight = max_prefill_weight.unwrap_or_default();
|
||||||
|
|
||||||
|
if max_total_tokens > max_batch_weight {
|
||||||
|
panic!("max_total_tokens cannot be greater than max_batch_weight");
|
||||||
|
}
|
||||||
|
|
||||||
// Create state
|
// Create state
|
||||||
let validation = Validation::new(
|
let validation = Validation::new(
|
||||||
validation_workers,
|
validation_workers,
|
||||||
@ -565,6 +576,8 @@ pub async fn run(
|
|||||||
client,
|
client,
|
||||||
validation,
|
validation,
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
|
max_batch_weight,
|
||||||
|
max_prefill_weight,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
);
|
);
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
use std::cmp::min;
|
||||||
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
||||||
/// Payload validation logic
|
/// Payload validation logic
|
||||||
use crate::{GenerateParameters, GenerateRequest};
|
use crate::{GenerateParameters, GenerateRequest};
|
||||||
@ -69,7 +70,7 @@ impl Validation {
|
|||||||
inputs: String,
|
inputs: String,
|
||||||
truncate: Option<usize>,
|
truncate: Option<usize>,
|
||||||
max_new_tokens: u32,
|
max_new_tokens: u32,
|
||||||
) -> Result<String, ValidationError> {
|
) -> Result<(String, usize), ValidationError> {
|
||||||
// If we have a fast tokenizer
|
// If we have a fast tokenizer
|
||||||
if let Some(sender) = &self.sender {
|
if let Some(sender) = &self.sender {
|
||||||
// Create response channel
|
// Create response channel
|
||||||
@ -105,7 +106,7 @@ impl Validation {
|
|||||||
}
|
}
|
||||||
|
|
||||||
metrics::histogram!("tgi_request_input_length", input_length as f64);
|
metrics::histogram!("tgi_request_input_length", input_length as f64);
|
||||||
Ok(inputs)
|
Ok((inputs, input_length))
|
||||||
}
|
}
|
||||||
// Return inputs without validation
|
// Return inputs without validation
|
||||||
else {
|
else {
|
||||||
@ -123,7 +124,7 @@ impl Validation {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(inputs)
|
Ok((inputs, min(truncate.unwrap_or(usize::MAX), self.max_input_length)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -238,7 +239,7 @@ impl Validation {
|
|||||||
.unwrap_or(Ok(None))?;
|
.unwrap_or(Ok(None))?;
|
||||||
|
|
||||||
// Validate inputs
|
// Validate inputs
|
||||||
let inputs = self
|
let (inputs, input_length) = self
|
||||||
.validate_input(request.inputs, truncate, max_new_tokens)
|
.validate_input(request.inputs, truncate, max_new_tokens)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
@ -262,7 +263,7 @@ impl Validation {
|
|||||||
|
|
||||||
Ok(ValidGenerateRequest {
|
Ok(ValidGenerateRequest {
|
||||||
inputs,
|
inputs,
|
||||||
truncate: truncate.unwrap_or(self.max_input_length) as u32,
|
truncate: input_length as u32,
|
||||||
parameters,
|
parameters,
|
||||||
stopping_parameters,
|
stopping_parameters,
|
||||||
})
|
})
|
||||||
|
Loading…
Reference in New Issue
Block a user