use max_size in the batch task

This commit is contained in:
OlivierDehaene 2024-02-08 17:26:55 +01:00
parent 9e042bd117
commit 2af011a1c0
2 changed files with 34 additions and 25 deletions

View File

@ -70,7 +70,7 @@ impl Infer {
tokenizer_config: HubTokenizerConfig, tokenizer_config: HubTokenizerConfig,
) -> Self { ) -> Self {
// Infer shared state // Infer shared state
let queue = Queue::new(requires_padding, max_batch_size, 16, window_size, speculate); let queue = Queue::new(requires_padding, 16, window_size, speculate);
let shared = Arc::new(Shared { let shared = Arc::new(Shared {
batching_task: Notify::new(), batching_task: Notify::new(),
}); });
@ -82,6 +82,7 @@ impl Infer {
max_batch_prefill_tokens, max_batch_prefill_tokens,
max_batch_total_tokens, max_batch_total_tokens,
max_waiting_tokens, max_waiting_tokens,
max_batch_size,
queue.clone(), queue.clone(),
shared.clone(), shared.clone(),
generation_health, generation_health,
@ -339,6 +340,7 @@ async fn batching_task(
max_batch_prefill_tokens: u32, max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
max_waiting_tokens: usize, max_waiting_tokens: usize,
max_batch_size: Option<usize>,
queue: Queue, queue: Queue,
shared: Arc<Shared>, shared: Arc<Shared>,
generation_health: Arc<AtomicBool>, generation_health: Arc<AtomicBool>,
@ -352,7 +354,12 @@ async fn batching_task(
// This batch might be smaller than the maximum batch size if there are not enough requests // This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue // waiting in the queue
while let Some((mut entries, batch, span)) = queue while let Some((mut entries, batch, span)) = queue
.next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens) .next_batch(
None,
max_batch_size,
max_batch_prefill_tokens,
max_batch_total_tokens,
)
.await .await
{ {
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health) let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
@ -380,10 +387,11 @@ async fn batching_task(
}; };
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let max_size = max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
// Try to get a new batch // Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_batch_prefill_tokens, token_budget) .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
.await .await
{ {
// Tracking metrics // Tracking metrics

View File

@ -36,7 +36,6 @@ pub(crate) struct Queue {
impl Queue { impl Queue {
pub(crate) fn new( pub(crate) fn new(
requires_padding: bool, requires_padding: bool,
max_batch_size: Option<usize>,
block_size: u32, block_size: u32,
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
@ -47,7 +46,6 @@ impl Queue {
// Launch background queue task // Launch background queue task
tokio::spawn(queue_task( tokio::spawn(queue_task(
requires_padding, requires_padding,
max_batch_size,
block_size, block_size,
window_size, window_size,
speculate, speculate,
@ -72,6 +70,7 @@ impl Queue {
pub(crate) async fn next_batch( pub(crate) async fn next_batch(
&self, &self,
min_size: Option<usize>, min_size: Option<usize>,
max_size: Option<usize>,
prefill_token_budget: u32, prefill_token_budget: u32,
token_budget: u32, token_budget: u32,
) -> Option<NextBatch> { ) -> Option<NextBatch> {
@ -82,6 +81,7 @@ impl Queue {
self.queue_sender self.queue_sender
.send(QueueCommand::NextBatch { .send(QueueCommand::NextBatch {
min_size, min_size,
max_size,
prefill_token_budget, prefill_token_budget,
token_budget, token_budget,
response_sender, response_sender,
@ -97,7 +97,6 @@ impl Queue {
// Background task responsible of the queue state // Background task responsible of the queue state
async fn queue_task( async fn queue_task(
requires_padding: bool, requires_padding: bool,
max_size: Option<usize>,
block_size: u32, block_size: u32,
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
@ -113,6 +112,7 @@ async fn queue_task(
} }
QueueCommand::NextBatch { QueueCommand::NextBatch {
min_size, min_size,
max_size,
prefill_token_budget, prefill_token_budget,
token_budget, token_budget,
response_sender, response_sender,
@ -332,6 +332,7 @@ enum QueueCommand {
Append(Box<Entry>, Span), Append(Box<Entry>, Span),
NextBatch { NextBatch {
min_size: Option<usize>, min_size: Option<usize>,
max_size: Option<usize>,
prefill_token_budget: u32, prefill_token_budget: u32,
token_budget: u32, token_budget: u32,
response_sender: oneshot::Sender<Option<NextBatch>>, response_sender: oneshot::Sender<Option<NextBatch>>,
@ -494,28 +495,28 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_append() { async fn test_queue_append() {
let queue = Queue::new(false, None, 1, None, 0); let queue = Queue::new(false, 1, None, 0);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
queue.append(entry); queue.append(entry);
} }
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_empty() { async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, None, 1, None, 0); let queue = Queue::new(false, 1, None, 0);
assert!(queue.next_batch(None, 1, 1).await.is_none()); assert!(queue.next_batch(None, None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
} }
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_min_size() { async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, None, 1, None, 0); let queue = Queue::new(false, 1, None, 0);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
queue.append(entry2); queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap(); let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
@ -528,11 +529,11 @@ mod tests {
queue.append(entry3); queue.append(entry3);
// Not enough requests pending // Not enough requests pending
assert!(queue.next_batch(Some(2), 2, 2).await.is_none()); assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
// Not enough token budget // Not enough token budget
assert!(queue.next_batch(Some(1), 0, 0).await.is_none()); assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
// Ok // Ok
let (entries2, batch2, _) = queue.next_batch(Some(1), 2, 2).await.unwrap(); let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
assert_eq!(entries2.len(), 1); assert_eq!(entries2.len(), 1);
assert!(entries2.contains_key(&2)); assert!(entries2.contains_key(&2));
assert!(entries2.get(&2).unwrap().batch_time.is_some()); assert!(entries2.get(&2).unwrap().batch_time.is_some());
@ -542,13 +543,13 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_max_size() { async fn test_queue_next_batch_max_size() {
let queue = Queue::new(false, Some(1), 1, None, 0); let queue = Queue::new(false, 1, None, 0);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
queue.append(entry2); queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap(); let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
assert_eq!(entries.len(), 1); assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert!(entries.get(&0).unwrap().batch_time.is_some()); assert!(entries.get(&0).unwrap().batch_time.is_some());
@ -558,13 +559,13 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_budget() { async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, None, 1, None, 0); let queue = Queue::new(false, 1, None, 0);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
queue.append(entry2); queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, 1, 1).await.unwrap(); let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
assert_eq!(entries.len(), 1); assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0); assert_eq!(batch.id, 0);
@ -573,7 +574,7 @@ mod tests {
let (entry3, _guard3) = default_entry(); let (entry3, _guard3) = default_entry();
queue.append(entry3); queue.append(entry3);
let (entries, batch, _) = queue.next_batch(None, 3, 3).await.unwrap(); let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2)); assert!(entries.contains_key(&2));
@ -583,16 +584,16 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_speculate() { async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, None, 1, None, 2); let queue = Queue::new(false, 1, None, 2);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
queue.append(entry2); queue.append(entry2);
// Budget of 1 is not enough // Budget of 1 is not enough
assert!(queue.next_batch(None, 1, 1).await.is_none()); assert!(queue.next_batch(None, None, 1, 1).await.is_none());
let (entries, batch, _) = queue.next_batch(None, 6, 6).await.unwrap(); let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
@ -602,10 +603,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_dropped_receiver() { async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, None, 1, None, 0); let queue = Queue::new(false, 1, None, 0);
let (entry, _) = default_entry(); let (entry, _) = default_entry();
queue.append(entry); queue.append(entry);
assert!(queue.next_batch(None, 1, 1).await.is_none()); assert!(queue.next_batch(None, None, 1, 1).await.is_none());
} }
} }