mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00
feat(router): add max_batch_size (#1542)
Some hardware requires a maximum batch size.
This commit is contained in:
parent
777e519277
commit
518d30dec4
@ -197,6 +197,14 @@ Options:
|
|||||||
[env: MAX_WAITING_TOKENS=]
|
[env: MAX_WAITING_TOKENS=]
|
||||||
[default: 20]
|
[default: 20]
|
||||||
|
|
||||||
|
```
|
||||||
|
## MAX_BATCH_SIZE
|
||||||
|
```shell
|
||||||
|
--max-batch-size <MAX_BATCH_SIZE>
|
||||||
|
Enforce a maximum number of requests per batch. Specific flag for hardware targets that do not support unpadded inference.
|
||||||
|
|
||||||
|
[env: MAX_BATCH_SIZE=]
|
||||||
|
|
||||||
```
|
```
|
||||||
## HOSTNAME
|
## HOSTNAME
|
||||||
```shell
|
```shell
|
||||||
|
@ -281,6 +281,11 @@ struct Args {
|
|||||||
#[clap(default_value = "20", long, env)]
|
#[clap(default_value = "20", long, env)]
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
|
|
||||||
|
/// Enforce a maximum number of requests per batch
|
||||||
|
/// Specific flag for hardware targets that do not support unpadded inference
|
||||||
|
#[clap(long, env)]
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
|
||||||
/// The IP address to listen on
|
/// The IP address to listen on
|
||||||
#[clap(default_value = "0.0.0.0", long, env)]
|
#[clap(default_value = "0.0.0.0", long, env)]
|
||||||
hostname: String,
|
hostname: String,
|
||||||
@ -1056,6 +1061,12 @@ fn spawn_webserver(
|
|||||||
router_args.push(max_batch_total_tokens.to_string());
|
router_args.push(max_batch_total_tokens.to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Router optional max batch size
|
||||||
|
if let Some(max_batch_size) = args.max_batch_size {
|
||||||
|
router_args.push("--max-batch-size".to_string());
|
||||||
|
router_args.push(max_batch_size.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
// Model optional revision
|
// Model optional revision
|
||||||
if let Some(ref revision) = args.revision {
|
if let Some(ref revision) = args.revision {
|
||||||
router_args.push("--revision".to_string());
|
router_args.push("--revision".to_string());
|
||||||
|
@ -109,7 +109,7 @@ impl Client {
|
|||||||
max_input_length: u32,
|
max_input_length: u32,
|
||||||
max_prefill_tokens: u32,
|
max_prefill_tokens: u32,
|
||||||
max_total_tokens: u32,
|
max_total_tokens: u32,
|
||||||
max_batch_total_tokens: Option<u32>,
|
max_batch_size: Option<usize>,
|
||||||
) -> Result<Option<u32>> {
|
) -> Result<Option<u32>> {
|
||||||
let warmup_enabled: bool = env::var("WARMUP_ENABLED").ok().map_or(true, |value| value.to_lowercase() == "true");
|
let warmup_enabled: bool = env::var("WARMUP_ENABLED").ok().map_or(true, |value| value.to_lowercase() == "true");
|
||||||
if !warmup_enabled {
|
if !warmup_enabled {
|
||||||
@ -142,17 +142,9 @@ impl Client {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if max_batch_size is None, create two batches
|
||||||
|
let num_batches = max_batch_size.unwrap_or(2).min(2);
|
||||||
let mut id_counter: u64 = 0;
|
let mut id_counter: u64 = 0;
|
||||||
let num_batches = match max_batch_total_tokens {
|
|
||||||
Some(val) => {
|
|
||||||
if val == max_total_tokens {
|
|
||||||
1
|
|
||||||
} else {
|
|
||||||
2
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => 2, // If max_batch_total_tokens is None, create two batches
|
|
||||||
};
|
|
||||||
for shape in shapes.iter() {
|
for shape in shapes.iter() {
|
||||||
// create two batches in order to trigger concatenate operation
|
// create two batches in order to trigger concatenate operation
|
||||||
// in case decode bs=1 create one batch
|
// in case decode bs=1 create one batch
|
||||||
|
@ -99,13 +99,18 @@ impl ShardedClient {
|
|||||||
max_input_length: u32,
|
max_input_length: u32,
|
||||||
max_prefill_tokens: u32,
|
max_prefill_tokens: u32,
|
||||||
max_total_tokens: u32,
|
max_total_tokens: u32,
|
||||||
max_batch_total_tokens: Option<u32>,
|
max_batch_size: Option<usize>,
|
||||||
) -> Result<Option<u32>> {
|
) -> Result<Option<u32>> {
|
||||||
let futures: Vec<_> = self
|
let futures: Vec<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| {
|
.map(|client| {
|
||||||
Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens, max_batch_total_tokens))
|
Box::pin(client.warmup(
|
||||||
|
max_input_length,
|
||||||
|
max_prefill_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
))
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
// Take the minimum value
|
// Take the minimum value
|
||||||
|
@ -63,6 +63,7 @@ impl Infer {
|
|||||||
max_batch_prefill_tokens: u32,
|
max_batch_prefill_tokens: u32,
|
||||||
max_batch_total_tokens: u32,
|
max_batch_total_tokens: u32,
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
max_concurrent_requests: usize,
|
max_concurrent_requests: usize,
|
||||||
requires_padding: bool,
|
requires_padding: bool,
|
||||||
max_input_length: u32,
|
max_input_length: u32,
|
||||||
@ -92,6 +93,7 @@ impl Infer {
|
|||||||
max_batch_prefill_tokens,
|
max_batch_prefill_tokens,
|
||||||
max_batch_total_tokens,
|
max_batch_total_tokens,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
queue.clone(),
|
queue.clone(),
|
||||||
shared.clone(),
|
shared.clone(),
|
||||||
generation_health,
|
generation_health,
|
||||||
@ -349,6 +351,7 @@ async fn batching_task(
|
|||||||
max_batch_prefill_tokens: u32,
|
max_batch_prefill_tokens: u32,
|
||||||
max_batch_total_tokens: u32,
|
max_batch_total_tokens: u32,
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
queue: Queue,
|
queue: Queue,
|
||||||
shared: Arc<Shared>,
|
shared: Arc<Shared>,
|
||||||
generation_health: Arc<AtomicBool>,
|
generation_health: Arc<AtomicBool>,
|
||||||
@ -362,7 +365,12 @@ async fn batching_task(
|
|||||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||||
// waiting in the queue
|
// waiting in the queue
|
||||||
while let Some((mut entries, batch, span)) = queue
|
while let Some((mut entries, batch, span)) = queue
|
||||||
.next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
|
.next_batch(
|
||||||
|
None,
|
||||||
|
max_batch_size,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
|
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
|
||||||
@ -390,10 +398,11 @@ async fn batching_task(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||||
|
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
|
||||||
|
|
||||||
// Try to get a new batch
|
// Try to get a new batch
|
||||||
if let Some((mut new_entries, new_batch, span)) = queue
|
if let Some((mut new_entries, new_batch, span)) = queue
|
||||||
.next_batch(min_size, max_batch_prefill_tokens, token_budget)
|
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
// Tracking metrics
|
// Tracking metrics
|
||||||
|
@ -73,6 +73,8 @@ pub struct Info {
|
|||||||
pub max_batch_total_tokens: u32,
|
pub max_batch_total_tokens: u32,
|
||||||
#[schema(example = "20")]
|
#[schema(example = "20")]
|
||||||
pub max_waiting_tokens: usize,
|
pub max_waiting_tokens: usize,
|
||||||
|
#[schema(nullable = true, example = "null")]
|
||||||
|
pub max_batch_size: Option<usize>,
|
||||||
#[schema(example = "2")]
|
#[schema(example = "2")]
|
||||||
pub validation_workers: usize,
|
pub validation_workers: usize,
|
||||||
/// Router Info
|
/// Router Info
|
||||||
|
@ -48,6 +48,8 @@ struct Args {
|
|||||||
max_batch_total_tokens: Option<u32>,
|
max_batch_total_tokens: Option<u32>,
|
||||||
#[clap(default_value = "20", long, env)]
|
#[clap(default_value = "20", long, env)]
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
|
#[clap(long, env)]
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
#[clap(default_value = "0.0.0.0", long, env)]
|
#[clap(default_value = "0.0.0.0", long, env)]
|
||||||
hostname: String,
|
hostname: String,
|
||||||
#[clap(default_value = "3000", long, short, env)]
|
#[clap(default_value = "3000", long, short, env)]
|
||||||
@ -94,6 +96,7 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
max_batch_prefill_tokens,
|
max_batch_prefill_tokens,
|
||||||
max_batch_total_tokens,
|
max_batch_total_tokens,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
hostname,
|
hostname,
|
||||||
port,
|
port,
|
||||||
master_shard_uds_path,
|
master_shard_uds_path,
|
||||||
@ -138,6 +141,25 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let (max_batch_size, max_batch_total_tokens) = match (max_batch_size, max_batch_total_tokens) {
|
||||||
|
(Some(_max_batch_size), Some(_max_batch_total_tokens)) => {
|
||||||
|
if (_max_batch_total_tokens as usize / max_total_tokens) != _max_batch_size {
|
||||||
|
tracing::warn!("max_batch_size was set to {_max_batch_size} while max_batch_total_tokens to {_max_batch_total_tokens}");
|
||||||
|
tracing::warn!("These values are not match, so max_batch_size will be preferred");
|
||||||
|
(Some(_max_batch_size), Some((_max_batch_size * max_total_tokens) as u32))
|
||||||
|
} else {
|
||||||
|
(Some(_max_batch_size), Some(_max_batch_total_tokens))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
(Some(_max_batch_size), None) => (
|
||||||
|
Some(_max_batch_size), Some((_max_batch_size * max_total_tokens) as u32)
|
||||||
|
),
|
||||||
|
(None, Some(_max_batch_total_tokens)) => (
|
||||||
|
Some(_max_batch_total_tokens as usize / max_total_tokens), Some(_max_batch_total_tokens)
|
||||||
|
),
|
||||||
|
(None, None) => (None, None),
|
||||||
|
};
|
||||||
|
|
||||||
// CORS allowed origins
|
// CORS allowed origins
|
||||||
// map to go inside the option and then map to parse from String to HeaderValue
|
// map to go inside the option and then map to parse from String to HeaderValue
|
||||||
// Finally, convert to AllowOrigin
|
// Finally, convert to AllowOrigin
|
||||||
@ -298,7 +320,7 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
max_input_length as u32,
|
max_input_length as u32,
|
||||||
max_batch_prefill_tokens,
|
max_batch_prefill_tokens,
|
||||||
max_total_tokens as u32,
|
max_total_tokens as u32,
|
||||||
max_batch_total_tokens,
|
max_batch_size,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(RouterError::Warmup)?
|
.map_err(RouterError::Warmup)?
|
||||||
@ -355,6 +377,7 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
max_batch_prefill_tokens,
|
max_batch_prefill_tokens,
|
||||||
max_supported_batch_total_tokens,
|
max_supported_batch_total_tokens,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
sharded_client,
|
sharded_client,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
validation_workers,
|
validation_workers,
|
||||||
|
@ -79,6 +79,7 @@ impl Queue {
|
|||||||
pub(crate) async fn next_batch(
|
pub(crate) async fn next_batch(
|
||||||
&self,
|
&self,
|
||||||
min_size: Option<usize>,
|
min_size: Option<usize>,
|
||||||
|
max_size: Option<usize>,
|
||||||
prefill_token_budget: u32,
|
prefill_token_budget: u32,
|
||||||
token_budget: u32,
|
token_budget: u32,
|
||||||
) -> Option<NextBatch> {
|
) -> Option<NextBatch> {
|
||||||
@ -89,6 +90,7 @@ impl Queue {
|
|||||||
self.queue_sender
|
self.queue_sender
|
||||||
.send(QueueCommand::NextBatch {
|
.send(QueueCommand::NextBatch {
|
||||||
min_size,
|
min_size,
|
||||||
|
max_size,
|
||||||
prefill_token_budget,
|
prefill_token_budget,
|
||||||
token_budget,
|
token_budget,
|
||||||
response_sender,
|
response_sender,
|
||||||
@ -128,12 +130,14 @@ async fn queue_task(
|
|||||||
}
|
}
|
||||||
QueueCommand::NextBatch {
|
QueueCommand::NextBatch {
|
||||||
min_size,
|
min_size,
|
||||||
|
max_size,
|
||||||
prefill_token_budget,
|
prefill_token_budget,
|
||||||
token_budget,
|
token_budget,
|
||||||
response_sender,
|
response_sender,
|
||||||
span,
|
span,
|
||||||
} => span.in_scope(|| {
|
} => span.in_scope(|| {
|
||||||
let next_batch = state.next_batch(min_size, prefill_token_budget, token_budget);
|
let next_batch =
|
||||||
|
state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
|
||||||
response_sender.send(next_batch).unwrap();
|
response_sender.send(next_batch).unwrap();
|
||||||
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
|
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
|
||||||
}),
|
}),
|
||||||
@ -308,6 +312,7 @@ impl State {
|
|||||||
fn next_batch(
|
fn next_batch(
|
||||||
&mut self,
|
&mut self,
|
||||||
min_size: Option<usize>,
|
min_size: Option<usize>,
|
||||||
|
max_size: Option<usize>,
|
||||||
prefill_token_budget: u32,
|
prefill_token_budget: u32,
|
||||||
token_budget: u32,
|
token_budget: u32,
|
||||||
) -> Option<NextBatch> {
|
) -> Option<NextBatch> {
|
||||||
@ -403,6 +408,11 @@ impl State {
|
|||||||
entry.batch_time = Some(Instant::now());
|
entry.batch_time = Some(Instant::now());
|
||||||
// Insert in batch_entries IntMap
|
// Insert in batch_entries IntMap
|
||||||
batch_entries.insert(id, entry);
|
batch_entries.insert(id, entry);
|
||||||
|
|
||||||
|
// Check if max_size
|
||||||
|
if Some(batch_requests.len()) == max_size {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Empty batch
|
// Empty batch
|
||||||
@ -451,6 +461,7 @@ enum QueueCommand {
|
|||||||
Append(Box<Entry>, Span),
|
Append(Box<Entry>, Span),
|
||||||
NextBatch {
|
NextBatch {
|
||||||
min_size: Option<usize>,
|
min_size: Option<usize>,
|
||||||
|
max_size: Option<usize>,
|
||||||
prefill_token_budget: u32,
|
prefill_token_budget: u32,
|
||||||
token_budget: u32,
|
token_budget: u32,
|
||||||
response_sender: oneshot::Sender<Option<NextBatch>>,
|
response_sender: oneshot::Sender<Option<NextBatch>>,
|
||||||
@ -535,8 +546,8 @@ mod tests {
|
|||||||
fn test_next_batch_empty() {
|
fn test_next_batch_empty() {
|
||||||
let mut state = default_state();
|
let mut state = default_state();
|
||||||
|
|
||||||
assert!(state.next_batch(None, 1, 1).is_none());
|
assert!(state.next_batch(None, None, 1, 1).is_none());
|
||||||
assert!(state.next_batch(Some(1), 1, 1).is_none());
|
assert!(state.next_batch(Some(1), None, 1, 1).is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -547,7 +558,7 @@ mod tests {
|
|||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
state.append(entry2);
|
state.append(entry2);
|
||||||
|
|
||||||
let (entries, batch, _) = state.next_batch(None, 2, 4).unwrap();
|
let (entries, batch, _) = state.next_batch(None, None, 2, 4).unwrap();
|
||||||
assert_eq!(entries.len(), 2);
|
assert_eq!(entries.len(), 2);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert!(entries.contains_key(&1));
|
assert!(entries.contains_key(&1));
|
||||||
@ -563,7 +574,7 @@ mod tests {
|
|||||||
let (entry3, _guard3) = default_entry();
|
let (entry3, _guard3) = default_entry();
|
||||||
state.append(entry3);
|
state.append(entry3);
|
||||||
|
|
||||||
assert!(state.next_batch(Some(2), 2, 2).is_none());
|
assert!(state.next_batch(Some(2), None, 2, 2).is_none());
|
||||||
|
|
||||||
assert_eq!(state.next_id, 3);
|
assert_eq!(state.next_id, 3);
|
||||||
assert_eq!(state.entries.len(), 1);
|
assert_eq!(state.entries.len(), 1);
|
||||||
@ -571,6 +582,26 @@ mod tests {
|
|||||||
assert_eq!(id, 2);
|
assert_eq!(id, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_next_batch_max_size() {
|
||||||
|
let mut state = default_state();
|
||||||
|
let (entry1, _guard1) = default_entry();
|
||||||
|
let (entry2, _guard2) = default_entry();
|
||||||
|
state.append(entry1);
|
||||||
|
state.append(entry2);
|
||||||
|
|
||||||
|
let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
|
||||||
|
assert_eq!(entries.len(), 1);
|
||||||
|
assert!(entries.contains_key(&0));
|
||||||
|
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
||||||
|
assert_eq!(batch.id, 0);
|
||||||
|
assert_eq!(batch.size, 1);
|
||||||
|
|
||||||
|
assert_eq!(state.next_id, 2);
|
||||||
|
assert_eq!(state.entries.len(), 1);
|
||||||
|
assert_eq!(state.next_batch_id, 1);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_next_batch_token_budget() {
|
fn test_next_batch_token_budget() {
|
||||||
let mut state = default_state();
|
let mut state = default_state();
|
||||||
@ -579,7 +610,7 @@ mod tests {
|
|||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
state.append(entry2);
|
state.append(entry2);
|
||||||
|
|
||||||
let (entries, batch, _) = state.next_batch(None, 1, 2).unwrap();
|
let (entries, batch, _) = state.next_batch(None, None, 1, 2).unwrap();
|
||||||
assert_eq!(entries.len(), 1);
|
assert_eq!(entries.len(), 1);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert_eq!(batch.id, 0);
|
assert_eq!(batch.id, 0);
|
||||||
@ -592,7 +623,7 @@ mod tests {
|
|||||||
let (entry3, _guard3) = default_entry();
|
let (entry3, _guard3) = default_entry();
|
||||||
state.append(entry3);
|
state.append(entry3);
|
||||||
|
|
||||||
let (entries, batch, _) = state.next_batch(None, 3, 6).unwrap();
|
let (entries, batch, _) = state.next_batch(None, None, 3, 6).unwrap();
|
||||||
assert_eq!(entries.len(), 2);
|
assert_eq!(entries.len(), 2);
|
||||||
assert!(entries.contains_key(&1));
|
assert!(entries.contains_key(&1));
|
||||||
assert!(entries.contains_key(&2));
|
assert!(entries.contains_key(&2));
|
||||||
@ -615,8 +646,8 @@ mod tests {
|
|||||||
async fn test_queue_next_batch_empty() {
|
async fn test_queue_next_batch_empty() {
|
||||||
let queue = default_queue();
|
let queue = default_queue();
|
||||||
|
|
||||||
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||||
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
|
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@ -627,7 +658,7 @@ mod tests {
|
|||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
queue.append(entry2);
|
queue.append(entry2);
|
||||||
|
|
||||||
let (entries, batch, _) = queue.next_batch(None, 2, 4).await.unwrap();
|
let (entries, batch, _) = queue.next_batch(None, None, 2, 4).await.unwrap();
|
||||||
assert_eq!(entries.len(), 2);
|
assert_eq!(entries.len(), 2);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert!(entries.contains_key(&1));
|
assert!(entries.contains_key(&1));
|
||||||
@ -640,11 +671,11 @@ mod tests {
|
|||||||
queue.append(entry3);
|
queue.append(entry3);
|
||||||
|
|
||||||
// Not enough requests pending
|
// Not enough requests pending
|
||||||
assert!(queue.next_batch(Some(2), 2, 2).await.is_none());
|
assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
|
||||||
// Not enough token budget
|
// Not enough token budget
|
||||||
assert!(queue.next_batch(Some(1), 0, 0).await.is_none());
|
assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
|
||||||
// Ok
|
// Ok
|
||||||
let (entries2, batch2, _) = queue.next_batch(Some(1), 1, 2).await.unwrap();
|
let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 4).await.unwrap();
|
||||||
assert_eq!(entries2.len(), 1);
|
assert_eq!(entries2.len(), 1);
|
||||||
assert!(entries2.contains_key(&2));
|
assert!(entries2.contains_key(&2));
|
||||||
assert!(entries2.get(&2).unwrap().batch_time.is_some());
|
assert!(entries2.get(&2).unwrap().batch_time.is_some());
|
||||||
@ -652,6 +683,22 @@ mod tests {
|
|||||||
assert_eq!(batch2.size, 1);
|
assert_eq!(batch2.size, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_queue_next_batch_max_size() {
|
||||||
|
let queue = default_queue();
|
||||||
|
let (entry1, _guard1) = default_entry();
|
||||||
|
let (entry2, _guard2) = default_entry();
|
||||||
|
queue.append(entry1);
|
||||||
|
queue.append(entry2);
|
||||||
|
|
||||||
|
let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
|
||||||
|
assert_eq!(entries.len(), 1);
|
||||||
|
assert!(entries.contains_key(&0));
|
||||||
|
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
||||||
|
assert_eq!(batch.id, 0);
|
||||||
|
assert_eq!(batch.size, 1);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_token_budget() {
|
async fn test_queue_next_batch_token_budget() {
|
||||||
let queue = default_queue();
|
let queue = default_queue();
|
||||||
@ -660,7 +707,7 @@ mod tests {
|
|||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
queue.append(entry2);
|
queue.append(entry2);
|
||||||
|
|
||||||
let (entries, batch, _) = queue.next_batch(None, 1, 2).await.unwrap();
|
let (entries, batch, _) = queue.next_batch(None, None, 1, 2).await.unwrap();
|
||||||
assert_eq!(entries.len(), 1);
|
assert_eq!(entries.len(), 1);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert_eq!(batch.id, 0);
|
assert_eq!(batch.id, 0);
|
||||||
@ -669,7 +716,7 @@ mod tests {
|
|||||||
let (entry3, _guard3) = default_entry();
|
let (entry3, _guard3) = default_entry();
|
||||||
queue.append(entry3);
|
queue.append(entry3);
|
||||||
|
|
||||||
let (entries, batch, _) = queue.next_batch(None, 2, 4).await.unwrap();
|
let (entries, batch, _) = queue.next_batch(None, None, 3, 6).await.unwrap();
|
||||||
assert_eq!(entries.len(), 2);
|
assert_eq!(entries.len(), 2);
|
||||||
assert!(entries.contains_key(&1));
|
assert!(entries.contains_key(&1));
|
||||||
assert!(entries.contains_key(&2));
|
assert!(entries.contains_key(&2));
|
||||||
@ -686,9 +733,9 @@ mod tests {
|
|||||||
queue.append(entry2);
|
queue.append(entry2);
|
||||||
|
|
||||||
// Budget of 1 is not enough
|
// Budget of 1 is not enough
|
||||||
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||||
|
|
||||||
let (entries, batch, _) = queue.next_batch(None, 6, 6).await.unwrap();
|
let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
|
||||||
assert_eq!(entries.len(), 2);
|
assert_eq!(entries.len(), 2);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert!(entries.contains_key(&1));
|
assert!(entries.contains_key(&1));
|
||||||
@ -702,6 +749,6 @@ mod tests {
|
|||||||
let (entry, _) = default_entry();
|
let (entry, _) = default_entry();
|
||||||
queue.append(entry);
|
queue.append(entry);
|
||||||
|
|
||||||
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -770,6 +770,7 @@ pub async fn run(
|
|||||||
max_batch_prefill_tokens: u32,
|
max_batch_prefill_tokens: u32,
|
||||||
max_batch_total_tokens: u32,
|
max_batch_total_tokens: u32,
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
client: ShardedClient,
|
client: ShardedClient,
|
||||||
tokenizer: Option<Tokenizer>,
|
tokenizer: Option<Tokenizer>,
|
||||||
validation_workers: usize,
|
validation_workers: usize,
|
||||||
@ -851,6 +852,7 @@ pub async fn run(
|
|||||||
max_batch_prefill_tokens,
|
max_batch_prefill_tokens,
|
||||||
max_batch_total_tokens,
|
max_batch_total_tokens,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
shard_info.requires_padding,
|
shard_info.requires_padding,
|
||||||
max_input_length as u32,
|
max_input_length as u32,
|
||||||
@ -934,6 +936,7 @@ pub async fn run(
|
|||||||
waiting_served_ratio,
|
waiting_served_ratio,
|
||||||
max_batch_total_tokens,
|
max_batch_total_tokens,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
validation_workers,
|
validation_workers,
|
||||||
version: env!("CARGO_PKG_VERSION"),
|
version: env!("CARGO_PKG_VERSION"),
|
||||||
sha: option_env!("VERGEN_GIT_SHA"),
|
sha: option_env!("VERGEN_GIT_SHA"),
|
||||||
|
Loading…
Reference in New Issue
Block a user