feat(router): use number of tokens in batch as input for dynamic batching (#226)

Co-authored-by: Nick Hill <nickhill@us.ibm.com>
OlivierDehaene 2023-04-24 17:59:00 +02:00 committed by GitHub
parent 98a3e0d135
commit ebc74d5666
16 changed files with 399 additions and 172 deletions

View File

@@ -39,8 +39,12 @@ struct Args {
     max_input_length: usize,
     #[clap(default_value = "1512", long, env)]
     max_total_tokens: usize,
-    #[clap(default_value = "32", long, env)]
-    max_batch_size: usize,
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
+    #[clap(default_value = "1.2", long, env)]
+    waiting_served_ratio: f32,
+    #[clap(default_value = "32000", long, env)]
+    max_batch_total_tokens: u32,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
@@ -93,6 +97,8 @@ fn main() -> ExitCode {
         max_input_length,
         max_total_tokens,
         max_batch_size,
+        max_batch_total_tokens,
+        waiting_served_ratio,
         max_waiting_tokens,
         port,
         shard_uds_path,
@@ -380,8 +386,8 @@ fn main() -> ExitCode {
         max_input_length.to_string(),
         "--max-total-tokens".to_string(),
         max_total_tokens.to_string(),
-        "--max-batch-size".to_string(),
-        max_batch_size.to_string(),
+        "--waiting-served-ratio".to_string(),
+        waiting_served_ratio.to_string(),
         "--max-waiting-tokens".to_string(),
         max_waiting_tokens.to_string(),
         "--port".to_string(),
@@ -392,6 +398,15 @@ fn main() -> ExitCode {
         model_id,
     ];

+    // Deprecate max_batch_size
+    if let Some(max_batch_size) = max_batch_size {
+        argv.push("--max-batch-size".to_string());
+        argv.push(max_batch_size.to_string())
+    } else {
+        argv.push("--max-batch-total-tokens".to_string());
+        argv.push(max_batch_total_tokens.to_string())
+    }
+
     // Model optional revision
     if let Some(ref revision) = revision {
         argv.push("--revision".to_string());

View File

@@ -9,6 +9,8 @@ service TextGenerationService {
     rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
     /// Empties batch cache
     rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
+    /// Remove requests from a cached batch
+    rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
     /// Prefill batch and decode first token
     rpc Prefill (PrefillRequest) returns (PrefillResponse);
     /// Decode token for a list of prefilled batches
@@ -89,6 +91,8 @@ message Batch {
     repeated Request requests = 2;
     /// Batch size (==len(requests))
     uint32 size = 3;
+    /// Maximum number of tokens this batch will grow to
+    uint32 max_tokens = 4;
 }

 enum FinishReason {
@@ -134,6 +138,19 @@ message Generation {
     GeneratedText generated_text = 7;
 }

+message FilterBatchRequest {
+    /// Batch ID
+    uint64 batch_id = 1;
+    /// Requests to keep
+    repeated Request keep_requests = 2;
+}
+
+message FilterBatchResponse {
+    /// Filtered Batch (cached)
+    Batch batch = 1;
+}
+
 message PrefillRequest {
     /// Batch
     Batch batch = 1;
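For orientation, a minimal Python sketch of how a client could call the new RPC once the protobuf stubs are regenerated. The `text_generation_server.pb` module path and the `TextGenerationServiceStub` class name come from standard gRPC code generation and are assumptions here, not part of this diff (the real server listens on a Unix socket and uses async stubs):

import grpc

# Assumed location of the regenerated stubs
from text_generation_server.pb import generate_pb2, generate_pb2_grpc

def filter_cached_batch(target: str, batch_id: int, keep_requests):
    """Ask a shard to drop every request not listed in `keep_requests` from a cached batch."""
    with grpc.insecure_channel(target) as channel:
        stub = generate_pb2_grpc.TextGenerationServiceStub(channel)
        response = stub.FilterBatch(
            generate_pb2.FilterBatchRequest(batch_id=batch_id, keep_requests=keep_requests)
        )
    # The returned batch carries the updated `max_tokens` bound for the remaining requests
    return response.batch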

View File

@@ -70,6 +70,22 @@ impl Client {
         Ok(())
     }

+    /// Filter a cached batch
+    #[instrument(skip(self))]
+    pub async fn filter_batch(
+        &mut self,
+        batch_id: u64,
+        keep_requests: Vec<Request>,
+    ) -> Result<Option<Batch>> {
+        let request = tonic::Request::new(FilterBatchRequest {
+            batch_id,
+            keep_requests,
+        })
+        .inject_context();
+        let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
+        Ok(filtered_batch.batch)
+    }
+
     /// Generate one token for each request in the given batch
     ///
     /// Returns Generation for each request in batch

View File

@@ -1,6 +1,6 @@
 /// Multi shard Client
 use crate::Result;
-use crate::{Batch, Client, Generation, ShardInfo};
+use crate::{Batch, Client, Generation, Request, ShardInfo};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;
@@ -59,6 +59,22 @@ impl ShardedClient {
         join_all(futures).await.into_iter().collect()
     }

+    /// Filter a cached batch
+    #[instrument(skip(self))]
+    pub async fn filter_batch(
+        &mut self,
+        batch_id: u64,
+        keep_requests: Vec<Request>,
+    ) -> Result<Option<Batch>> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| Box::pin(client.filter_batch(batch_id, keep_requests.clone())))
+            .collect();
+        // all shards return the same message
+        join_all(futures).await.pop().unwrap()
+    }
+
     /// Generate one token for each request in the given batch
     ///
     /// Returns Generation for each request in batch

View File

@@ -39,12 +39,14 @@ impl Infer {
     pub(crate) fn new(
         client: ShardedClient,
         validation: Validation,
-        max_batch_size: usize,
+        waiting_served_ratio: f32,
+        max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
+        requires_padding: bool,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new();
+        let queue = Queue::new(requires_padding);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
@@ -52,7 +54,8 @@ impl Infer {
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
             client,
-            max_batch_size,
+            waiting_served_ratio,
+            max_batch_total_tokens,
             max_waiting_tokens,
             queue.clone(),
             shared.clone(),
@@ -232,18 +235,12 @@ impl Infer {
 /// Batches requests and sends them to the inference server
 async fn batching_task(
     mut client: ShardedClient,
-    max_batch_size: usize,
+    waiting_served_ratio: f32,
+    max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
 ) {
-    // Minimum batch size after which we try to add more requests
-    let limit_min_batch_size = if max_batch_size > 1 {
-        (max_batch_size / 2) as u32
-    } else {
-        0
-    };
-
     // Infinite loop
     loop {
         // Wait for a notification from the Infer struct
@@ -252,7 +249,9 @@ async fn batching_task(
         // Get the next batch from the queue
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the queue
-        while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
+        while let Some((mut entries, batch, span)) =
+            queue.next_batch(None, max_batch_total_tokens).await
+        {
             let mut cached_batch = prefill(&mut client, batch, &mut entries)
                 .instrument(span)
                 .await;
@@ -263,48 +262,57 @@ async fn batching_task(
             while let Some(batch) = cached_batch {
                 // Get current batch info
                 let batch_size = batch.size;
+                let batch_max_tokens = batch.max_tokens;
                 let mut batches = vec![batch];
                 metrics::gauge!("tgi_batch_current_size", batch_size as f64);
+                metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);

-                // If the current batch is too small, we try to add more requests to it
-                if batch_size <= limit_min_batch_size {
-                    let min_size = match waiting_tokens {
-                        // If we didn't onboard any new requests since >= max_waiting_tokens, we try
-                        // to add a new batch even though its size might be small
-                        _ if waiting_tokens >= max_waiting_tokens => None,
-                        // Minimum size criteria
-                        _ => Some(limit_min_batch_size as usize),
-                    };
-
-                    // Try to get a new batch
-                    if let Some((mut new_entries, new_batch, span)) = queue
-                        .next_batch(min_size, max_batch_size - batch_size as usize)
-                        .await
-                    {
-                        entries.iter_mut().for_each(|(_, entry)| {
-                            // Create a new span to add the info that this entry is waiting
-                            // because a new batch is being computed
-                            let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
-                            // Add relationships
-                            span.follows_from(&entry_waiting_span);
-                            entry_waiting_span.follows_from(&span);
-                            // Update entry
-                            entry.temp_span = Some(entry_waiting_span);
-                        });
-
-                        // Generate one token for this new batch to have the attention past in cache
-                        let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
-                            .instrument(span)
-                            .await;
-                        // Reset waiting counter
-                        waiting_tokens = 1;
-                        // Extend current batch with the new batch
-                        if let Some(new_cached_batch) = new_cached_batch {
-                            entries.extend(new_entries);
-                            batches.push(new_cached_batch);
-                        }
-                    }
-                }
+                let min_size = if waiting_tokens >= max_waiting_tokens {
+                    // If we didn't onboard any new requests since >= max_waiting_tokens, we try
+                    // to add a new batch even though its size might be small
+                    None
+                } else {
+                    // Minimum batch size
+                    Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
+                };
+
+                let token_budget = max_batch_total_tokens - batch_max_tokens;
+
+                // Try to get a new batch
+                if let Some((mut new_entries, new_batch, span)) =
+                    queue.next_batch(min_size, token_budget).await
+                {
+                    // Tracking metrics
+                    if min_size.is_some() {
+                        metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
+                    } else {
+                        metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
+                    }
+
+                    entries.iter_mut().for_each(|(_, entry)| {
+                        // Create a new span to add the info that this entry is waiting
+                        // because a new batch is being computed
+                        let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
+                        // Add relationships
+                        span.follows_from(&entry_waiting_span);
+                        entry_waiting_span.follows_from(&span);
+                        // Update entry
+                        entry.temp_span = Some(entry_waiting_span);
+                    });
+
+                    // Generate one token for this new batch to have the attention past in cache
+                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+                        .instrument(span)
+                        .await;
+                    // Reset waiting counter
+                    waiting_tokens = 1;
+                    // Extend current batch with the new batch
+                    if let Some(new_cached_batch) = new_cached_batch {
+                        entries.extend(new_entries);
+                        batches.push(new_cached_batch);
+                    }
+                }

                 // Create span for this batch to add context to inference calls
                 let next_batch_size = entries.len();
                 let next_batch_span =
@@ -325,6 +333,7 @@ async fn batching_task(
                 waiting_tokens += 1;
             }
             metrics::gauge!("tgi_batch_current_size", 0.0);
+            metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
         }
     }
 }
@@ -341,22 +350,11 @@ async fn prefill(
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
+            // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);

             // Filter next batch and remove requests that were stopped
-            let next_batch = match next_batch {
-                None => None,
-                Some(batch) => {
-                    let id = batch.id;
-                    let next_batch = filter_batch(batch, entries);
-                    // Next batch is now empty
-                    // Clear it from the Python shards cache
-                    if next_batch.is_none() {
-                        let _ = client.clear_cache(Some(id)).await;
-                    }
-                    next_batch
-                }
-            };
+            let next_batch = filter_batch(client, next_batch, entries).await;

             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
@@ -384,22 +382,11 @@ async fn decode(
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
+            // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);

             // Filter next batch and remove requests that were stopped
-            let next_batch = match next_batch {
-                None => None,
-                Some(batch) => {
-                    let id = batch.id;
-                    let next_batch = filter_batch(batch, entries);
-                    // Next batch is now empty
-                    // Clear it from the Python shards cache
-                    if next_batch.is_none() {
-                        let _ = client.clear_cache(Some(id)).await;
-                    }
-                    next_batch
-                }
-            };
+            let next_batch = filter_batch(client, next_batch, entries).await;

             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
@@ -419,14 +406,35 @@ async fn decode(
 /// Filter a `batch` and remove all requests not present in `entries`
 #[instrument(skip_all)]
-fn filter_batch(mut batch: Batch, entries: &IntMap<u64, Entry>) -> Option<Batch> {
-    batch.requests.retain(|r| entries.contains_key(&r.id));
-    let size = batch.requests.len();
-    if size == 0 {
-        return None;
+async fn filter_batch(
+    client: &mut ShardedClient,
+    next_batch: Option<Batch>,
+    entries: &IntMap<u64, Entry>,
+) -> Option<Batch> {
+    let mut batch = next_batch?;
+
+    // No need to filter
+    if batch.size as usize == entries.len() {
+        return Some(batch);
+    }
+
+    let id = batch.id;
+
+    // Retain only requests that are still in entries
+    batch.requests.retain(|r| entries.contains_key(&r.id));
+
+    if batch.requests.is_empty() {
+        // All requests have been filtered out
+        // Next batch is now empty
+        // Clear it from the Python shards cache
+        // We unwrap here as we need to panic since we cannot recover if this method fails
+        client.clear_cache(Some(id)).await.unwrap();
+        None
+    } else {
+        // Filter Python shard cache
+        // We unwrap here as we need to panic since we cannot recover if this method fails
+        client.filter_batch(id, batch.requests).await.unwrap()
     }
-    batch.size = size as u32;
-    Some(batch)
 }

 /// Send one or multiple `InferStreamResponse` to Infer for all `entries`
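In plain terms, each engine iteration now derives two numbers from the running batch: a minimum size that any concatenated batch must reach (so a large running batch is not slowed down for a negligible gain) and the token budget left under `max_batch_total_tokens`. A minimal Python sketch of that arithmetic, with illustrative numbers only (the authoritative logic is the Rust code above):

import math

def concat_parameters(batch_size, batch_max_tokens, waiting_tokens,
                      max_waiting_tokens=20, waiting_served_ratio=1.2,
                      max_batch_total_tokens=32000):
    """Mirror of the router's decision: returns (min_size, token_budget)."""
    if waiting_tokens >= max_waiting_tokens:
        min_size = None  # the queue has waited long enough: accept any batch size
    else:
        min_size = math.floor(batch_size * waiting_served_ratio)
    token_budget = max_batch_total_tokens - batch_max_tokens
    return min_size, token_budget

# e.g. a running batch of 8 requests that may grow to 12_000 tokens:
# min_size = 9, token_budget = 20_000
print(concat_parameters(batch_size=8, batch_max_tokens=12_000, waiting_tokens=3))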

View File

@@ -31,8 +31,12 @@ struct Args {
     max_input_length: usize,
     #[clap(default_value = "1512", long, env)]
     max_total_tokens: usize,
-    #[clap(default_value = "32", long, env)]
-    max_batch_size: usize,
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
+    #[clap(default_value = "1.2", long, env)]
+    waiting_served_ratio: f32,
+    #[clap(default_value = "32000", long, env)]
+    max_batch_total_tokens: u32,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
@@ -64,6 +68,8 @@ fn main() -> Result<(), std::io::Error> {
         max_input_length,
         max_total_tokens,
         max_batch_size,
+        waiting_served_ratio,
+        mut max_batch_total_tokens,
         max_waiting_tokens,
         port,
         master_shard_uds_path,
@@ -119,6 +125,12 @@ fn main() -> Result<(), std::io::Error> {
         .block_on(async {
             init_logging(otlp_endpoint, json_output);

+            if let Some(max_batch_size) = max_batch_size {
+                tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
+                max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
+                tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
+            }
+
             if tokenizer.is_none() {
                 tracing::warn!(
                     "Could not find a fast tokenizer implementation for {tokenizer_name}"
@@ -174,7 +186,8 @@ fn main() -> Result<(), std::io::Error> {
                 max_stop_sequences,
                 max_input_length,
                 max_total_tokens,
-                max_batch_size,
+                waiting_served_ratio,
+                max_batch_total_tokens,
                 max_waiting_tokens,
                 sharded_client,
                 tokenizer,
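The override above keeps the deprecated flag meaningful: a `--max-batch-size` value is converted into a token budget by assuming every slot could hold a request at the `max_total_tokens` cap. A small illustration of that conversion (the numbers are only an example):

def budget_from_deprecated_max_batch_size(max_batch_size: int, max_total_tokens: int) -> int:
    """Worst case: every slot filled with a request at the per-request token cap."""
    return max_batch_size * max_total_tokens

# With the previous defaults (32 requests of up to 1512 total tokens each)
# the override yields 48384, i.e. a larger budget than the new 32000 default.
print(budget_from_deprecated_max_batch_size(32, 1512))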

View File

@@ -2,7 +2,6 @@ use crate::infer::InferError;
 use crate::infer::InferStreamResponse;
 use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
-use std::cmp::min;
 use std::collections::VecDeque;
 use text_generation_client::{Batch, Request};
 use tokio::sync::oneshot;
@@ -34,12 +33,12 @@ pub(crate) struct Queue {
 }

 impl Queue {
-    pub(crate) fn new() -> Self {
+    pub(crate) fn new(requires_padding: bool) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = flume::unbounded();

         // Launch background queue task
-        tokio::spawn(queue_task(queue_receiver));
+        tokio::spawn(queue_task(requires_padding, queue_receiver));

         Self { queue_sender }
     }
@@ -59,7 +58,7 @@ impl Queue {
     pub(crate) async fn next_batch(
         &self,
         min_size: Option<usize>,
-        max_size: usize,
+        token_budget: u32,
     ) -> Option<NextBatch> {
         // Create response channel
         let (response_sender, response_receiver) = oneshot::channel();
@@ -68,7 +67,7 @@ impl Queue {
         self.queue_sender
             .send(QueueCommand::NextBatch {
                 min_size,
-                max_size,
+                token_budget,
                 response_sender,
                 span: Span::current(),
             })
@@ -80,20 +79,24 @@ impl Queue {
 }

 // Background task responsible of the queue state
-async fn queue_task(receiver: flume::Receiver<QueueCommand>) {
-    let mut state = State::new();
+async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueCommand>) {
+    let mut state = State::new(requires_padding);

     while let Ok(cmd) = receiver.recv_async().await {
         match cmd {
-            QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)),
+            QueueCommand::Append(entry, span) => {
+                span.in_scope(|| state.append(entry));
+                metrics::increment_gauge!("tgi_queue_size", 1.0);
+            }
             QueueCommand::NextBatch {
                 min_size,
-                max_size,
+                token_budget,
                 response_sender,
                 span,
             } => span.in_scope(|| {
-                let next_batch = state.next_batch(min_size, max_size);
+                let next_batch = state.next_batch(min_size, token_budget);
                 response_sender.send(next_batch).unwrap_or(());
+                metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
             }),
         }
     }
@@ -110,14 +113,18 @@ struct State {

     /// Id of the next batch
     next_batch_id: u64,
+
+    /// Whether the model is using padding
+    requires_padding: bool,
 }

 impl State {
-    fn new() -> Self {
+    fn new(requires_padding: bool) -> Self {
         Self {
             entries: VecDeque::with_capacity(128),
             next_id: 0,
             next_batch_id: 0,
+            requires_padding,
         }
     }
@@ -130,11 +137,10 @@ impl State {
         // Push entry in the queue
         self.entries.push_back((self.next_id, entry));
         self.next_id += 1;
-        metrics::increment_gauge!("tgi_queue_size", 1.0);
     }

     // Get the next batch
-    fn next_batch(&mut self, min_size: Option<usize>, max_size: usize) -> Option<NextBatch> {
+    fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
         if self.entries.is_empty() {
             return None;
         }
@@ -146,17 +152,19 @@ impl State {
             }
         }

-        let max_batch_size = min(self.entries.len(), max_size);
-
         // Create span for this batch to add context to inference calls
         let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
         next_batch_span.follows_from(&Span::current());

-        let mut batch_requests = Vec::with_capacity(max_batch_size);
+        let mut batch_requests = Vec::with_capacity(self.entries.len());
         let mut batch_entries =
-            IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default());
+            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

-        // Iterate on buffer
+        let mut max_input_length = 0;
+        let mut prefill_tokens: u32 = 0;
+        let mut decode_tokens: u32 = 0;
+
+        // Pop entries starting from the front of the queue
         while let Some((id, mut entry)) = self.entries.pop_front() {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)
@@ -165,6 +173,24 @@ impl State {
                 continue;
             }

+            if self.requires_padding {
+                // We pad to max input length in the Python shards
+                // We need to take these padding tokens into the equation
+                max_input_length = max_input_length.max(entry.request.input_length);
+                prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
+            } else {
+                prefill_tokens += entry.request.input_length;
+            }
+
+            decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+
+            if (prefill_tokens + decode_tokens) > token_budget {
+                // Entry is over budget
+                // Add it back to the front
+                self.entries.push_front((id, entry));
+                break;
+            }
+
             // Create a new span to link the batch back to this entry
             let entry_batch_span = info_span!(parent: &entry.span, "infer");
             // Add relationships
@@ -184,21 +210,29 @@ impl State {
             entry.batch_time = Some(Instant::now());
             // Insert in batch_entries IntMap
             batch_entries.insert(id, entry);
-
-            if batch_requests.len() == max_batch_size {
-                // We have enough requests in the batch
-                break;
-            }
         }

-        metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
-
-        // Maybe all entries were dropped because their channel were closed
+        // Empty batch
         if batch_requests.is_empty() {
             return None;
         }

-        // Final batch size once we dropped entries
+        // Check if our batch is big enough
+        if let Some(min_size) = min_size {
+            // Batch is too small
+            if batch_requests.len() < min_size {
+                // Add back entries to the queue in the correct order
+                for r in batch_requests.into_iter().rev() {
+                    let id = r.id;
+                    let entry = batch_entries.remove(&id).unwrap();
+                    self.entries.push_front((id, entry));
+                }
+
+                return None;
+            }
+        }
+
+        // Final batch size
         let size = batch_requests.len() as u32;
         next_batch_span.record("batch_size", size);
@@ -206,11 +240,13 @@ impl State {
             id: self.next_batch_id,
             requests: batch_requests,
             size,
+            max_tokens: (prefill_tokens + decode_tokens),
         };
         // Increment batch id
         self.next_batch_id += 1;

         metrics::histogram!("tgi_batch_next_size", batch.size as f64);
+
         Some((batch_entries, batch, next_batch_span))
     }
 }
@@ -222,7 +258,7 @@ enum QueueCommand {
     Append(Entry, Span),
     NextBatch {
         min_size: Option<usize>,
-        max_size: usize,
+        token_budget: u32,
         response_sender: oneshot::Sender<Option<NextBatch>>,
         span: Span,
     },
@@ -243,6 +279,7 @@ mod tests {
         let entry = Entry {
             request: ValidGenerateRequest {
                 inputs: "".to_string(),
+                input_length: 0,
                 truncate: 0,
                 parameters: NextTokenChooserParameters {
                     temperature: 0.0,
@@ -256,7 +293,7 @@ mod tests {
                 },
                 stopping_parameters: StoppingCriteriaParameters {
                     ignore_eos_token: false,
-                    max_new_tokens: 0,
+                    max_new_tokens: 1,
                     stop_sequences: vec![],
                 },
             },
@@ -271,7 +308,7 @@ mod tests {

     #[test]
     fn test_append() {
-        let mut state = State::new();
+        let mut state = State::new(false);
         let (entry, _guard) = default_entry();

         assert_eq!(state.next_id, 0);
@@ -287,7 +324,7 @@ mod tests {

     #[test]
     fn test_next_batch_empty() {
-        let mut state = State::new();
+        let mut state = State::new(false);

         assert!(state.next_batch(None, 1).is_none());
         assert!(state.next_batch(Some(1), 1).is_none());
@@ -295,7 +332,7 @@ mod tests {

     #[test]
     fn test_next_batch_min_size() {
-        let mut state = State::new();
+        let mut state = State::new(false);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
@@ -326,8 +363,8 @@ mod tests {
     }

     #[test]
-    fn test_next_batch_max_size() {
-        let mut state = State::new();
+    fn test_next_batch_token_budget() {
+        let mut state = State::new(false);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
@@ -360,14 +397,14 @@ mod tests {

     #[tokio::test]
     async fn test_queue_append() {
-        let queue = Queue::new();
+        let queue = Queue::new(false);
         let (entry, _guard) = default_entry();
         queue.append(entry);
     }

     #[tokio::test]
     async fn test_queue_next_batch_empty() {
-        let queue = Queue::new();
+        let queue = Queue::new(false);

         assert!(queue.next_batch(None, 1).await.is_none());
         assert!(queue.next_batch(Some(1), 1).await.is_none());
@@ -375,7 +412,7 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
-        let queue = Queue::new();
+        let queue = Queue::new(false);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -397,8 +434,8 @@ mod tests {
     }

     #[tokio::test]
-    async fn test_queue_next_batch_max_size() {
-        let queue = Queue::new();
+    async fn test_queue_next_batch_token_budget() {
+        let queue = Queue::new(false);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -423,7 +460,7 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_dropped_receiver() {
-        let queue = Queue::new();
+        let queue = Queue::new(false);
         let (entry, _) = default_entry();
         queue.append(entry);
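The key change is in how `next_batch` measures a candidate batch: instead of counting requests, it adds up the tokens the batch could occupy (prompt tokens, accounted at the longest prompt length when the model requires padding, plus every request's remaining `max_new_tokens`) and stops as soon as the running total would exceed the budget. A minimal Python sketch of that accounting, using made-up `(input_length, max_new_tokens)` pairs rather than real queue entries:

def take_under_budget(requests, token_budget, requires_padding):
    """Greedily take requests from the front until the token accounting busts the budget."""
    taken = []
    max_input_length = 0
    prefill_tokens = 0
    decode_tokens = 0
    for input_length, max_new_tokens in requests:
        if requires_padding:
            # padded models allocate every sequence at the longest prompt length
            max_input_length = max(max_input_length, input_length)
            prefill_tokens = (len(taken) + 1) * max_input_length
        else:
            prefill_tokens += input_length
        decode_tokens += max_new_tokens
        if prefill_tokens + decode_tokens > token_budget:
            break  # this request goes back to the front of the queue
        taken.append((input_length, max_new_tokens))
    return taken

# Three waiting requests against a 1000-token budget:
queue = [(100, 100), (500, 100), (100, 100)]
print(len(take_under_budget(queue, 1000, requires_padding=False)))  # 3
print(len(take_under_budget(queue, 1000, requires_padding=True)))   # 1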

View File

@@ -511,7 +511,8 @@ pub async fn run(
     max_stop_sequences: usize,
     max_input_length: usize,
     max_total_tokens: usize,
-    max_batch_size: usize,
+    waiting_served_ratio: f32,
+    max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
     client: ShardedClient,
     tokenizer: Option<Tokenizer>,
@@ -571,9 +572,11 @@ pub async fn run(
     let infer = Infer::new(
         client,
         validation,
-        max_batch_size,
+        waiting_served_ratio,
+        max_batch_total_tokens,
         max_waiting_tokens,
         max_concurrent_requests,
+        shard_info.requires_padding,
     );

     // Duration buckets
@@ -604,7 +607,7 @@ pub async fn run(
         .collect();
     // Batch size buckets
     let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
-    let batch_size_buckets: Vec<f64> = (0..max_batch_size).map(|x| (x + 1) as f64).collect();
+    let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();

     // Prometheus handler
     let builder = PrometheusBuilder::new()

View File

@@ -69,7 +69,7 @@ impl Validation {
         inputs: String,
         truncate: Option<usize>,
         max_new_tokens: u32,
-    ) -> Result<String, ValidationError> {
+    ) -> Result<(String, usize), ValidationError> {
         // If we have a fast tokenizer
         if let Some(sender) = &self.sender {
             // Create response channel
@@ -105,25 +105,24 @@ impl Validation {
             }

             metrics::histogram!("tgi_request_input_length", input_length as f64);
-            Ok(inputs)
+            Ok((inputs, input_length))
         }
         // Return inputs without validation
         else {
             // In this case, we don't know the real length in tokens of the inputs
             // However, the inputs will be truncated by the python servers
             // We make sure that truncate + max_new_tokens <= self.max_total_tokens
+            let input_length = truncate.unwrap_or(self.max_input_length);

             // Validate MaxNewTokens
-            if (truncate.unwrap_or(self.max_input_length) as u32 + max_new_tokens)
-                > self.max_total_tokens as u32
-            {
+            if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
                 return Err(ValidationError::MaxNewTokens(
                     self.max_total_tokens - self.max_input_length,
                     max_new_tokens,
                 ));
             }

-            Ok(inputs)
+            Ok((inputs, input_length))
         }
     }
@@ -238,7 +237,7 @@ impl Validation {
             .unwrap_or(Ok(None))?;

         // Validate inputs
-        let inputs = self
+        let (inputs, input_length) = self
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;
@@ -262,6 +261,7 @@ impl Validation {

         Ok(ValidGenerateRequest {
             inputs,
+            input_length: input_length as u32,
             truncate: truncate.unwrap_or(self.max_input_length) as u32,
             parameters,
             stopping_parameters,
@@ -333,6 +333,7 @@ type TokenizerRequest = (
 #[derive(Debug)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
+    pub input_length: u32,
     pub truncate: u32,
     pub parameters: NextTokenChooserParameters,
     pub stopping_parameters: StoppingCriteriaParameters,

View File

@@ -181,9 +181,7 @@ def test_causal_lm_generate_token_completion_multi(
     next_batch = next_batch.filter([next_batch.requests[0]])

     for _ in range(
-        stopping_criterias[0].max_new_tokens
-        - stopping_criterias[1].max_new_tokens
-        - 1
+        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
         generations, next_batch = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)

View File

@@ -174,14 +174,14 @@ def test_causal_lm_generate_token_completion_multi(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
     # Copy stopping_criterias before filtering
-    stopping_criterias = default_multi_requests_causal_lm_batch.stopping_criterias.copy()
+    stopping_criterias = (
+        default_multi_requests_causal_lm_batch.stopping_criterias.copy()
+    )

     next_batch = next_batch.filter([next_batch.requests[0]])

     for _ in range(
-        stopping_criterias[0].max_new_tokens
-        - stopping_criterias[1].max_new_tokens
-        - 1
+        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
         generations, next_batch = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

View File

@@ -46,6 +46,9 @@ class CausalLMBatch(Batch):
     max_input_length: int
     padding_right_offset: int

+    # Maximum number of tokens this batch will grow to
+    max_tokens: int
+
     # Past metadata
     keys_head_dim_last: bool = True
@@ -54,6 +57,7 @@ class CausalLMBatch(Batch):
             id=self.batch_id,
             requests=self.requests,
             size=len(self),
+            max_tokens=self.max_tokens,
         )

     @classmethod
@@ -73,6 +77,7 @@ class CausalLMBatch(Batch):
         # Parse batch
         max_truncation = 0
         padding_right_offset = 0
+        max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
@@ -84,6 +89,7 @@ class CausalLMBatch(Batch):
             )
             stopping_criterias.append(stopping_criteria)
             max_truncation = max(max_truncation, r.truncate)
+            max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -112,6 +118,8 @@ class CausalLMBatch(Batch):
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

+        max_tokens = len(inputs) * max_input_length + max_decode_tokens
+
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -128,6 +136,7 @@ class CausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
+            max_tokens=max_tokens,
         )

     @tracer.start_as_current_span("filter")
@@ -150,6 +159,7 @@ class CausalLMBatch(Batch):

         next_token_choosers = []
         stopping_criterias = []
+        total_remaining_decode_tokens = 0
         new_padding_right_offset = 0

         for i, r in enumerate(requests):
@@ -168,19 +178,23 @@ class CausalLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
-            new_padding_right_offset = max(
-                new_padding_right_offset,
+            remaining_decode_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
+            total_remaining_decode_tokens += remaining_decode_tokens
+            new_padding_right_offset = max(
+                new_padding_right_offset, remaining_decode_tokens
+            )

         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
         input_ids = self.input_ids[keep_indices]
         position_ids = self.position_ids[keep_indices]
         self.attention_mask = self.attention_mask[
             keep_indices,
-            -(self.padding_right_offset + max_input_length):
-            (self.attention_mask.shape[1] - self.padding_right_offset) + new_padding_right_offset,
+            -(self.padding_right_offset + max_input_length) : (
+                self.attention_mask.shape[1] - self.padding_right_offset
+            )
+            + new_padding_right_offset,
         ]

         # Ensure that past_key_values tensors can be updated in-place
@@ -203,6 +217,8 @@ class CausalLMBatch(Batch):
                     layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
                 del past_values

+        max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens
+
         self.requests = requests
         self.requests_idx_mapping = requests_idx_mapping
         self.input_ids = input_ids
@@ -215,6 +231,7 @@ class CausalLMBatch(Batch):
         self.stopping_criterias = stopping_criterias
         self.max_input_length = max_input_length
         self.padding_right_offset = new_padding_right_offset
+        self.max_tokens = max_tokens

         return self
@@ -239,6 +256,7 @@ class CausalLMBatch(Batch):
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []
+        max_tokens = 0

         # Batch tensors
         input_ids = None
@@ -314,7 +332,8 @@ class CausalLMBatch(Batch):
             # And ensure that we can update tensors in-place
             if type(batch.past_key_values[0]) == tuple:
                 batch.past_key_values = [
-                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values
+                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
+                    for layer in batch.past_key_values
                 ]
             elif batch.past_key_values[0][0].shape == 3:
                 for layer in batch.past_key_values:
@@ -322,6 +341,10 @@ class CausalLMBatch(Batch):
                         layer[k] = t.view(len(batch), -1, *t.shape[-2:])

             start_index = end_index
+            # Add eventual padding tokens that were added while concatenating
+            max_tokens += batch.max_tokens + (
+                max_input_length - batch.max_input_length
+            ) * len(batch)

         first_past_kvs = batches[0].past_key_values
         _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
@@ -371,7 +394,9 @@ class CausalLMBatch(Batch):

             start_index = end_index

-            padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape)
+            padded_past_values = first_past_kvs[j][1].new_zeros(
+                padded_past_values_shape
+            )
             start_index = 0
             for batch in batches:
                 past_values = batch.past_key_values[j][1]
@@ -387,6 +412,7 @@ class CausalLMBatch(Batch):
                 ] = past_values[:, :, -past_seq_len:, :]
                 del past_values

+                # Update values
                 start_index = end_index

             past_key_values.append([padded_past_keys, padded_past_values])
@@ -408,6 +434,7 @@ class CausalLMBatch(Batch):
             max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
+            max_tokens=max_tokens,
         )

     def __len__(self):
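Each Python batch now reports the worst case it can grow to: every kept sequence padded to the longest prompt in the batch, plus all the decode tokens its requests may still produce; concatenation additionally charges for the padding needed to align shorter batches to the widest one. A small, self-contained Python illustration of that bookkeeping with plain numbers (no tensors, not the actual batch classes):

def batch_max_tokens(input_lengths, max_new_tokens_left):
    """Padded prefill footprint plus the decode tokens the batch may still generate."""
    max_input_length = max(input_lengths)
    return len(input_lengths) * max_input_length + sum(max_new_tokens_left)

def concatenated_max_tokens(batches, new_max_input_length):
    """Charge every batch for the extra padding needed to align it to the widest prompt."""
    return sum(
        max_tokens + (new_max_input_length - batch_max_input_length) * size
        for size, batch_max_input_length, max_tokens in batches
    )

# Two prompts of 10 and 30 tokens with 20 new tokens left each: 2 * 30 + 40 = 100
print(batch_max_tokens([10, 30], [20, 20]))
# Concatenating that batch (size 2, padded to 30, max_tokens 100) with one padded to 50
# (size 1, max_tokens 70) re-pads the first one: 100 + 20 * 2 + 70 = 210
print(concatenated_max_tokens([(2, 30, 100), (1, 50, 70)], new_max_input_length=50))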

View File

@@ -56,9 +56,15 @@ class FlashCausalLMBatch(Batch):
     # Constant shared tensor, ref here just so that it's accessible in concatentate()
     past_pad: Optional[torch.Tensor]

+    # Maximum number of tokens this batch will grow to
+    max_tokens: int
+
     def to_pb(self) -> generate_pb2.Batch:
         return generate_pb2.Batch(
-            id=self.batch_id, requests=self.requests, size=len(self)
+            id=self.batch_id,
+            requests=self.requests,
+            size=len(self),
+            max_tokens=self.max_tokens,
         )

     @classmethod
@@ -86,6 +92,8 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_length = 0

+        max_tokens = 0
+
         # Parse batch
         for i, r in enumerate(pb.requests):
             # request id -> idx in list mapping
@@ -115,16 +123,20 @@ class FlashCausalLMBatch(Batch):
             cu_seqlens.append(cumulative_length + input_length)

             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
+            max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)

             all_input_ids_tensor.append(
                 F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens))
             )

             # Update
             cumulative_length += input_length
+            max_tokens += input_length + max_new_tokens

         return cls(
             batch_id=pb.id,
@@ -143,6 +155,7 @@ class FlashCausalLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             past_pad=None,
+            max_tokens=max_tokens,
         )

     @tracer.start_as_current_span("filter")
@@ -177,6 +190,8 @@ class FlashCausalLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []

+        max_tokens = 0
+
         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
             requests_idx_mapping[r.id] = i
@@ -203,9 +218,14 @@ class FlashCausalLMBatch(Batch):
             token_offsets.append(self.token_offsets[idx])

             next_token_choosers.append(self.next_token_choosers[idx])
-            stopping_criterias.append(self.stopping_criterias[idx])
+
+            stopping_criteria = self.stopping_criterias[idx]
+            stopping_criterias.append(stopping_criteria)

             cumulative_length += request_input_length
+            max_tokens += request_input_length + (
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            )

         if single_request:
             # Preallocate tensor for bs = 1 case
@@ -241,6 +261,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            max_tokens=max_tokens,
         )

     @classmethod
@@ -269,6 +290,7 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_batch_size = 0
         cumulative_length = 0
+        max_tokens = 0

         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
@@ -310,6 +332,7 @@ class FlashCausalLMBatch(Batch):
             # Update
             cumulative_length += batch.cu_seqlens[-1]
             cumulative_batch_size += len(batch)
+            max_tokens += batch.max_tokens

         return FlashCausalLMBatch(
             batch_id=batches[0].batch_id,
@@ -328,6 +351,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            max_tokens=max_tokens,
         )

     def __len__(self):

View File

@@ -101,6 +101,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         # Parse batch
         max_truncation = 0
         padding_right_offset = 0
+        max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
@@ -113,6 +114,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             )
             stopping_criterias.append(stopping_criteria)
             max_truncation = max(max_truncation, r.truncate)
+            max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -141,6 +143,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

+        max_tokens = len(inputs) * max_input_length + max_decode_tokens
+
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -157,6 +161,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
+            max_tokens=max_tokens,
         )

View File

@@ -54,10 +54,16 @@ class Seq2SeqLMBatch(Batch):
     max_decoder_input_length: int
     padding_right_offset: int

+    # Maximum number of tokens this batch will grow to
+    max_tokens: int
+
     def to_pb(self) -> generate_pb2.Batch:
         """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
         return generate_pb2.Batch(
-            id=self.batch_id, requests=self.requests, size=len(self)
+            id=self.batch_id,
+            requests=self.requests,
+            size=len(self),
+            max_tokens=self.max_tokens,
         )

     @classmethod
@@ -80,6 +86,7 @@ class Seq2SeqLMBatch(Batch):
         # Parse batch
         max_truncation = 0
         padding_right_offset = 0
+        max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
@@ -92,6 +99,7 @@ class Seq2SeqLMBatch(Batch):
             )
             stopping_criterias.append(stopping_criteria)
             max_truncation = max(max_truncation, r.truncate)
+            max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -117,6 +125,8 @@ class Seq2SeqLMBatch(Batch):
         )
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)

+        max_tokens = len(inputs) * max_input_length + max_decode_tokens
+
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -137,6 +147,7 @@ class Seq2SeqLMBatch(Batch):
             max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
+            max_tokens=max_tokens,
         )

     @tracer.start_as_current_span("filter")
@@ -166,6 +177,8 @@ class Seq2SeqLMBatch(Batch):
         max_decoder_input_length = 0
         padding_right_offset = 0

+        remaining_decode_tokens = 0
+
         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
             requests_idx_mapping[r.id] = i
@@ -187,27 +200,38 @@ class Seq2SeqLMBatch(Batch):
             )
             padding_right_offset = max(
                 padding_right_offset,
-                self.stopping_criterias[idx].max_new_tokens - self.stopping_criterias[idx].current_tokens
+                self.stopping_criterias[idx].max_new_tokens
+                - self.stopping_criterias[idx].current_tokens,
             )

             next_token_choosers.append(self.next_token_choosers[idx])
-            stopping_criterias.append(self.stopping_criterias[idx])
+            stopping_criteria = self.stopping_criterias[idx]
+            stopping_criterias.append(stopping_criteria)
+            remaining_decode_tokens += (
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            )

         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
         self.decoder_input_ids = self.decoder_input_ids[keep_indices]
         self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
         if self.decoder_attention_mask is not None:
             self.decoder_attention_mask = self.decoder_attention_mask[
                 keep_indices,
-                -(self.padding_right_offset + max_decoder_input_length):
-                (self.decoder_attention_mask.shape[1] - self.padding_right_offset) + padding_right_offset,
+                -(self.padding_right_offset + max_decoder_input_length) : (
+                    self.decoder_attention_mask.shape[1] - self.padding_right_offset
+                )
+                + padding_right_offset,
             ]

-        self.encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices, -max_input_length:]
+        self.encoder_last_hidden_state = self.encoder_last_hidden_state[
+            keep_indices, -max_input_length:
+        ]

         # Ensure that past_key_values tensors can be updated in-place
         if type(self.past_key_values[0]) == tuple:
-            self.past_key_values = [[t for t in layer] for layer in self.past_key_values]
+            self.past_key_values = [
+                [t for t in layer] for layer in self.past_key_values
+            ]

         decoder_past_seq_len = max_decoder_input_length - 1
         for layer in self.past_key_values:
@@ -216,6 +240,11 @@ class Seq2SeqLMBatch(Batch):
             layer[2] = layer[2][keep_indices, :, -max_input_length:]
             layer[3] = layer[3][keep_indices, :, -max_input_length:]

+        max_tokens = (
+            len(requests) * (max_input_length + max_decoder_input_length)
+            + remaining_decode_tokens
+        )
+
         self.requests = requests
         self.requests_idx_mapping = requests_idx_mapping
         self.input_ids = None
@@ -229,10 +258,10 @@ class Seq2SeqLMBatch(Batch):
         self.max_input_length = max_input_length
         self.max_decoder_input_length = max_decoder_input_length
         self.padding_right_offset = padding_right_offset
+        self.max_tokens = max_tokens

         return self

     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
@@ -261,6 +290,7 @@ class Seq2SeqLMBatch(Batch):
         token_offsets = []
         next_token_choosers = []
         stopping_criterias = []
+        max_tokens = 0

         # Batch tensors
         attention_mask = None
@@ -363,9 +393,18 @@ class Seq2SeqLMBatch(Batch):
             # Ensure that we can update tensors in-place
             if type(batch.past_key_values[0]) == tuple:
-                batch.past_key_values = [[t for t in layer] for layer in batch.past_key_values]
+                batch.past_key_values = [
+                    [t for t in layer] for layer in batch.past_key_values
+                ]

             start_index = end_index

+            # Add eventual padding tokens that were added while concatenating
+            max_tokens += batch.max_tokens + (
+                max_input_length
+                - batch.max_input_length
+                + max_decoder_input_length
+                - batch.max_decoder_input_length
+            ) * len(batch)
+
         # Determine shapes for new past kv tensors
         first_past_kvs = batches[0].past_key_values
@@ -404,9 +443,9 @@ class Seq2SeqLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past keys and values to remove the padding from previous batches
                 past_seq_len = batch.max_decoder_input_length - 1
-                padded_past_values[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = t[:, :, -past_seq_len:, :]
+                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
+                    :, :, -past_seq_len:, :
+                ]
                 del t

                 start_index = end_index
@@ -426,8 +465,8 @@ class Seq2SeqLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past keys and values to remove the padding from previous batches
                 padded_past_values[
-                    start_index:end_index, :, -batch.max_input_length:, :
-                ] = t[:, :, -batch.max_input_length:, :]
+                    start_index:end_index, :, -batch.max_input_length :, :
+                ] = t[:, :, -batch.max_input_length :, :]
                 del t

                 start_index = end_index
@@ -452,6 +491,7 @@ class Seq2SeqLMBatch(Batch):
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
             padding_right_offset=padding_right_offset,
+            max_tokens=max_tokens,
         )

     def __len__(self):
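For encoder-decoder models the same bookkeeping counts both sides of the padded rectangle: each kept request is charged for the longest encoder input plus the longest decoder prefix, and then for whatever decode tokens it may still produce. A small illustration with made-up lengths (plain arithmetic, not the batch class itself):

def seq2seq_max_tokens(num_requests, max_input_length, max_decoder_input_length,
                       remaining_decode_tokens):
    """Worst-case footprint of a filtered Seq2Seq batch."""
    return (
        num_requests * (max_input_length + max_decoder_input_length)
        + remaining_decode_tokens
    )

# 2 requests, encoder padded to 64, decoder prefix padded to 8, 30 + 10 new tokens left:
print(seq2seq_max_tokens(2, 64, 8, 30 + 10))  # 2 * 72 + 40 = 184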

View File

@@ -41,6 +41,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             torch.cuda.empty_cache()
         return generate_pb2.ClearCacheResponse()

+    async def FilterBatch(self, request, context):
+        batch = self.cache.pop(request.batch_id)
+        if batch is None:
+            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
+        filtered_batch = batch.filter(request.keep_requests)
+        self.cache.set(filtered_batch)
+
+        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
+
     async def Prefill(self, request, context):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.device
@@ -63,9 +72,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batch = self.cache.pop(batch_pb.id)
             if batch is None:
                 raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
-            batch = batch.filter(batch_pb.requests)
-            if batch is not None:
-                batches.append(batch)
+            batches.append(batch)
         if len(batches) == 0:
             raise ValueError("All batches are empty")