feat(router): add max_batch_size (#1542)

Some hardware requires a maximum batch size.
OlivierDehaene 2024-02-09 12:38:41 +01:00 committed by Karol Damaszke
parent 777e519277
commit 518d30dec4
9 changed files with 134 additions and 34 deletions

docs/source/basic_tutorials/launcher.md

@@ -197,6 +197,14 @@ Options:
           [env: MAX_WAITING_TOKENS=]
           [default: 20]
 
 ```
+## MAX_BATCH_SIZE
+```shell
+      --max-batch-size <MAX_BATCH_SIZE>
+          Enforce a maximum number of requests per batch Specific flag for hardware targets that do not support unpadded inference
+
+          [env: MAX_BATCH_SIZE=]
+
+```
 ## HOSTNAME
 ```shell
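On such hardware the cap can be supplied either through the environment or the flag; for example (hypothetical model id): `MAX_BATCH_SIZE=4 text-generation-launcher --model-id my-model`, or equivalently `text-generation-launcher --model-id my-model --max-batch-size 4`.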

launcher/src/main.rs

@@ -281,6 +281,11 @@ struct Args {
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
 
+    /// Enforce a maximum number of requests per batch
+    /// Specific flag for hardware targets that do not support unpadded inference
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
+
     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
     hostname: String,
@@ -1056,6 +1061,12 @@ fn spawn_webserver(
         router_args.push(max_batch_total_tokens.to_string());
     }
 
+    // Router optional max batch size
+    if let Some(max_batch_size) = args.max_batch_size {
+        router_args.push("--max-batch-size".to_string());
+        router_args.push(max_batch_size.to_string());
+    }
+
     // Model optional revision
     if let Some(ref revision) = args.revision {
         router_args.push("--revision".to_string());

router/client/src/client.rs

@@ -109,7 +109,7 @@ impl Client {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
-        max_batch_total_tokens: Option<u32>,
+        max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let warmup_enabled: bool = env::var("WARMUP_ENABLED").ok().map_or(true, |value| value.to_lowercase() == "true");
         if !warmup_enabled {
@@ -142,17 +142,9 @@ impl Client {
             }
         }
 
+        // if max_batch_size is None, create two batches
+        let num_batches = max_batch_size.unwrap_or(2).min(2);
         let mut id_counter: u64 = 0;
-        let num_batches = match max_batch_total_tokens {
-            Some(val) => {
-                if val == max_total_tokens {
-                    1
-                } else {
-                    2
-                }
-            }
-            None => 2, // If max_batch_total_tokens is None, create two batches
-        };
         for shape in shapes.iter() {
             // create two batches in order to trigger concatenate operation
             // in case decode bs=1 create one batch
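The new batch-count rule collapses the old `match` into one expression: with no cap, warmup builds two batches so the concatenate path is exercised; a cap of one forces a single batch; any larger cap is still clamped to two. A minimal sketch of the rule as a standalone function (hypothetical name, not part of the diff):

```rust
// Warmup batch-count rule: None => 2, Some(1) => 1, Some(n >= 2) => 2.
fn warmup_num_batches(max_batch_size: Option<usize>) -> usize {
    max_batch_size.unwrap_or(2).min(2)
}

fn main() {
    assert_eq!(warmup_num_batches(None), 2);
    assert_eq!(warmup_num_batches(Some(1)), 1);
    assert_eq!(warmup_num_batches(Some(8)), 2);
}
```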

router/client/src/sharded_client.rs

@@ -99,13 +99,18 @@ impl ShardedClient {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
-        max_batch_total_tokens: Option<u32>,
+        max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| {
-                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens, max_batch_total_tokens))
+                Box::pin(client.warmup(
+                    max_input_length,
+                    max_prefill_tokens,
+                    max_total_tokens,
+                    max_batch_size,
+                ))
             })
             .collect();
         // Take the minimum value

router/src/infer.rs

@@ -63,6 +63,7 @@ impl Infer {
         max_batch_prefill_tokens: u32,
         max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
+        max_batch_size: Option<usize>,
         max_concurrent_requests: usize,
         requires_padding: bool,
         max_input_length: u32,
@@ -92,6 +93,7 @@ impl Infer {
             max_batch_prefill_tokens,
             max_batch_total_tokens,
             max_waiting_tokens,
+            max_batch_size,
             queue.clone(),
             shared.clone(),
             generation_health,
@@ -349,6 +351,7 @@ async fn batching_task(
     max_batch_prefill_tokens: u32,
     max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
+    max_batch_size: Option<usize>,
     queue: Queue,
     shared: Arc<Shared>,
     generation_health: Arc<AtomicBool>,
@@ -362,7 +365,12 @@ async fn batching_task(
         // This batch might be smaller than the maximum batch size if there are not enough requests
        // waiting in the queue
         while let Some((mut entries, batch, span)) = queue
-            .next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
+            .next_batch(
+                None,
+                max_batch_size,
+                max_batch_prefill_tokens,
+                max_batch_total_tokens,
+            )
             .await
         {
             let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
@@ -390,10 +398,11 @@ async fn batching_task(
                 };
                 let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
+                let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
 
                 // Try to get a new batch
                 if let Some((mut new_entries, new_batch, span)) = queue
-                    .next_batch(min_size, max_batch_prefill_tokens, token_budget)
+                    .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
                     .await
                 {
                     // Tracking metrics
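When extending a running batch, the room left under the cap is simply the cap minus the current batch size. A sketch of that computation as a hypothetical standalone function; it assumes the running batch never exceeds the cap (which the batching loop guarantees), since the plain subtraction would otherwise underflow:

```rust
// Remaining request slots under an optional cap.
// Example: a cap of 4 with 3 requests in flight leaves room for 1 more;
// no cap means the queue is limited only by the token budget.
fn remaining_slots(max_batch_size: Option<usize>, batch_size: u32) -> Option<usize> {
    max_batch_size.map(|max_size| max_size - batch_size as usize)
}

fn main() {
    assert_eq!(remaining_slots(Some(4), 3), Some(1));
    assert_eq!(remaining_slots(None, 3), None);
}
```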

router/src/lib.rs

@@ -73,6 +73,8 @@ pub struct Info {
     pub max_batch_total_tokens: u32,
     #[schema(example = "20")]
     pub max_waiting_tokens: usize,
+    #[schema(nullable = true, example = "null")]
+    pub max_batch_size: Option<usize>,
     #[schema(example = "2")]
     pub validation_workers: usize,
     /// Router Info
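Since `Info` is the payload served by the router's `/info` endpoint, a configured cap now also shows up there (serialized as `null` when unset).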

router/src/main.rs

@@ -48,6 +48,8 @@ struct Args {
     max_batch_total_tokens: Option<u32>,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
     #[clap(default_value = "0.0.0.0", long, env)]
     hostname: String,
     #[clap(default_value = "3000", long, short, env)]
@@ -94,6 +96,7 @@ async fn main() -> Result<(), RouterError> {
         max_batch_prefill_tokens,
         max_batch_total_tokens,
         max_waiting_tokens,
+        max_batch_size,
         hostname,
         port,
         master_shard_uds_path,
@@ -138,6 +141,25 @@ async fn main() -> Result<(), RouterError> {
         }
     }
 
+    let (max_batch_size, max_batch_total_tokens) = match (max_batch_size, max_batch_total_tokens) {
+        (Some(_max_batch_size), Some(_max_batch_total_tokens)) => {
+            if (_max_batch_total_tokens as usize / max_total_tokens) != _max_batch_size {
+                tracing::warn!("max_batch_size was set to {_max_batch_size} while max_batch_total_tokens to {_max_batch_total_tokens}");
+                tracing::warn!("These values are not match, so max_batch_size will be preferred");
+                (Some(_max_batch_size), Some((_max_batch_size * max_total_tokens) as u32))
+            } else {
+                (Some(_max_batch_size), Some(_max_batch_total_tokens))
+            }
+        },
+        (Some(_max_batch_size), None) => (
+            Some(_max_batch_size), Some((_max_batch_size * max_total_tokens) as u32)
+        ),
+        (None, Some(_max_batch_total_tokens)) => (
+            Some(_max_batch_total_tokens as usize / max_total_tokens), Some(_max_batch_total_tokens)
+        ),
+        (None, None) => (None, None),
+    };
+
     // CORS allowed origins
     // map to go inside the option and then map to parse from String to HeaderValue
     // Finally, convert to AllowOrigin
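The two knobs are reconciled so they cannot contradict each other: `max_batch_size` is authoritative, and `max_batch_total_tokens` is (re)derived as `max_batch_size * max_total_tokens` whenever the pair disagrees. A worked sketch of the same arithmetic as a hypothetical free function:

```rust
// max_batch_size wins; max_batch_total_tokens is derived from it when the
// two values disagree (using the same integer division as the router).
fn reconcile(
    max_batch_size: Option<usize>,
    max_batch_total_tokens: Option<u32>,
    max_total_tokens: usize,
) -> (Option<usize>, Option<u32>) {
    match (max_batch_size, max_batch_total_tokens) {
        (Some(bs), Some(tt)) if (tt as usize / max_total_tokens) != bs => {
            (Some(bs), Some((bs * max_total_tokens) as u32))
        }
        (Some(bs), None) => (Some(bs), Some((bs * max_total_tokens) as u32)),
        (None, Some(tt)) => (Some(tt as usize / max_total_tokens), Some(tt)),
        other => other,
    }
}

fn main() {
    // With max_total_tokens = 2048: a cap of 4 implies a budget of 8192;
    // an inconsistent explicit budget (4096 => batch size 2) is overridden.
    assert_eq!(reconcile(Some(4), None, 2048), (Some(4), Some(8192)));
    assert_eq!(reconcile(Some(4), Some(4096), 2048), (Some(4), Some(8192)));
    assert_eq!(reconcile(None, Some(8192), 2048), (Some(4), Some(8192)));
}
```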
@@ -298,7 +320,7 @@ async fn main() -> Result<(), RouterError> {
             max_input_length as u32,
             max_batch_prefill_tokens,
             max_total_tokens as u32,
-            max_batch_total_tokens,
+            max_batch_size,
         )
         .await
         .map_err(RouterError::Warmup)?
@@ -355,6 +377,7 @@ async fn main() -> Result<(), RouterError> {
         max_batch_prefill_tokens,
         max_supported_batch_total_tokens,
         max_waiting_tokens,
+        max_batch_size,
         sharded_client,
         tokenizer,
         validation_workers,

router/src/queue.rs

@@ -79,6 +79,7 @@ impl Queue {
     pub(crate) async fn next_batch(
         &self,
         min_size: Option<usize>,
+        max_size: Option<usize>,
         prefill_token_budget: u32,
         token_budget: u32,
     ) -> Option<NextBatch> {
@@ -89,6 +90,7 @@ impl Queue {
             self.queue_sender
                 .send(QueueCommand::NextBatch {
                     min_size,
+                    max_size,
                     prefill_token_budget,
                     token_budget,
                     response_sender,
@@ -128,12 +130,14 @@ async fn queue_task(
             }
             QueueCommand::NextBatch {
                 min_size,
+                max_size,
                 prefill_token_budget,
                 token_budget,
                 response_sender,
                 span,
             } => span.in_scope(|| {
-                let next_batch = state.next_batch(min_size, prefill_token_budget, token_budget);
+                let next_batch =
+                    state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
                 response_sender.send(next_batch).unwrap();
                 metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
             }),
@@ -308,6 +312,7 @@ impl State {
     fn next_batch(
         &mut self,
         min_size: Option<usize>,
+        max_size: Option<usize>,
         prefill_token_budget: u32,
         token_budget: u32,
     ) -> Option<NextBatch> {
@@ -403,6 +408,11 @@ impl State {
             entry.batch_time = Some(Instant::now());
             // Insert in batch_entries IntMap
             batch_entries.insert(id, entry);
+
+            // Check if max_size
+            if Some(batch_requests.len()) == max_size {
+                break;
+            }
         }
 
         // Empty batch
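The cap is ultimately enforced by this single comparison in the queue's fill loop: once the batch under construction holds exactly `max_size` requests, the loop stops pulling entries even if prefill and token budget remain. A minimal sketch with hypothetical types:

```rust
// Fill a batch from queued request ids, stopping at an optional size cap.
fn fill_batch(queued_ids: &[u64], max_size: Option<usize>) -> Vec<u64> {
    let mut batch_requests = Vec::new();
    for &id in queued_ids {
        batch_requests.push(id);
        // Same check as in the diff: break once the cap is reached.
        if Some(batch_requests.len()) == max_size {
            break;
        }
    }
    batch_requests
}

fn main() {
    assert_eq!(fill_batch(&[0, 1, 2], Some(2)), vec![0, 1]);
    assert_eq!(fill_batch(&[0, 1, 2], None), vec![0, 1, 2]); // uncapped
}
```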
@@ -451,6 +461,7 @@ enum QueueCommand {
     Append(Box<Entry>, Span),
     NextBatch {
         min_size: Option<usize>,
+        max_size: Option<usize>,
         prefill_token_budget: u32,
         token_budget: u32,
         response_sender: oneshot::Sender<Option<NextBatch>>,
@@ -535,8 +546,8 @@ mod tests {
     fn test_next_batch_empty() {
         let mut state = default_state();
 
-        assert!(state.next_batch(None, 1, 1).is_none());
-        assert!(state.next_batch(Some(1), 1, 1).is_none());
+        assert!(state.next_batch(None, None, 1, 1).is_none());
+        assert!(state.next_batch(Some(1), None, 1, 1).is_none());
     }
 
     #[test]
@@ -547,7 +558,7 @@ mod tests {
         state.append(entry1);
         state.append(entry2);
 
-        let (entries, batch, _) = state.next_batch(None, 2, 4).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 2, 4).unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -563,7 +574,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         state.append(entry3);
 
-        assert!(state.next_batch(Some(2), 2, 2).is_none());
+        assert!(state.next_batch(Some(2), None, 2, 2).is_none());
 
         assert_eq!(state.next_id, 3);
         assert_eq!(state.entries.len(), 1);
@@ -571,6 +582,26 @@ mod tests {
         assert_eq!(id, 2);
     }
 
+    #[test]
+    fn test_next_batch_max_size() {
+        let mut state = default_state();
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        state.append(entry1);
+        state.append(entry2);
+
+        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
+        assert_eq!(entries.len(), 1);
+        assert!(entries.contains_key(&0));
+        assert!(entries.get(&0).unwrap().batch_time.is_some());
+        assert_eq!(batch.id, 0);
+        assert_eq!(batch.size, 1);
+
+        assert_eq!(state.next_id, 2);
+        assert_eq!(state.entries.len(), 1);
+        assert_eq!(state.next_batch_id, 1);
+    }
+
     #[test]
     fn test_next_batch_token_budget() {
         let mut state = default_state();
@@ -579,7 +610,7 @@ mod tests {
         state.append(entry1);
         state.append(entry2);
 
-        let (entries, batch, _) = state.next_batch(None, 1, 2).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 1, 2).unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert_eq!(batch.id, 0);
@@ -592,7 +623,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         state.append(entry3);
 
-        let (entries, batch, _) = state.next_batch(None, 3, 6).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 3, 6).unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&1));
         assert!(entries.contains_key(&2));
@@ -615,8 +646,8 @@ mod tests {
     async fn test_queue_next_batch_empty() {
         let queue = default_queue();
 
-        assert!(queue.next_batch(None, 1, 1).await.is_none());
-        assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
+        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
+        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
     }
 
     #[tokio::test]
@@ -627,7 +658,7 @@ mod tests {
         queue.append(entry1);
         queue.append(entry2);
 
-        let (entries, batch, _) = queue.next_batch(None, 2, 4).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 2, 4).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -640,11 +671,11 @@ mod tests {
         queue.append(entry3);
 
         // Not enough requests pending
-        assert!(queue.next_batch(Some(2), 2, 2).await.is_none());
+        assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
         // Not enough token budget
-        assert!(queue.next_batch(Some(1), 0, 0).await.is_none());
+        assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
         // Ok
-        let (entries2, batch2, _) = queue.next_batch(Some(1), 1, 2).await.unwrap();
+        let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 4).await.unwrap();
         assert_eq!(entries2.len(), 1);
         assert!(entries2.contains_key(&2));
         assert!(entries2.get(&2).unwrap().batch_time.is_some());
@@ -652,6 +683,22 @@ mod tests {
         assert_eq!(batch2.size, 1);
     }
 
+    #[tokio::test]
+    async fn test_queue_next_batch_max_size() {
+        let queue = default_queue();
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        queue.append(entry1);
+        queue.append(entry2);
+
+        let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
+        assert_eq!(entries.len(), 1);
+        assert!(entries.contains_key(&0));
+        assert!(entries.get(&0).unwrap().batch_time.is_some());
+        assert_eq!(batch.id, 0);
+        assert_eq!(batch.size, 1);
+    }
+
     #[tokio::test]
     async fn test_queue_next_batch_token_budget() {
         let queue = default_queue();
@@ -660,7 +707,7 @@ mod tests {
         queue.append(entry1);
         queue.append(entry2);
 
-        let (entries, batch, _) = queue.next_batch(None, 1, 2).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 1, 2).await.unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert_eq!(batch.id, 0);
@@ -669,7 +716,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         queue.append(entry3);
 
-        let (entries, batch, _) = queue.next_batch(None, 2, 4).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 3, 6).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&1));
         assert!(entries.contains_key(&2));
@@ -686,9 +733,9 @@ mod tests {
         queue.append(entry2);
 
         // Budget of 1 is not enough
-        assert!(queue.next_batch(None, 1, 1).await.is_none());
+        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
 
-        let (entries, batch, _) = queue.next_batch(None, 6, 6).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -702,6 +749,6 @@ mod tests {
         let (entry, _) = default_entry();
         queue.append(entry);
 
-        assert!(queue.next_batch(None, 1, 1).await.is_none());
+        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
     }
 }

router/src/server.rs

@@ -770,6 +770,7 @@ pub async fn run(
     max_batch_prefill_tokens: u32,
     max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
+    max_batch_size: Option<usize>,
     client: ShardedClient,
     tokenizer: Option<Tokenizer>,
     validation_workers: usize,
@@ -851,6 +852,7 @@ pub async fn run(
         max_batch_prefill_tokens,
         max_batch_total_tokens,
         max_waiting_tokens,
+        max_batch_size,
         max_concurrent_requests,
         shard_info.requires_padding,
         max_input_length as u32,
@@ -934,6 +936,7 @@ pub async fn run(
         waiting_served_ratio,
         max_batch_total_tokens,
         max_waiting_tokens,
+        max_batch_size,
        validation_workers,
         version: env!("CARGO_PKG_VERSION"),
         sha: option_env!("VERGEN_GIT_SHA"),