mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
use max_size in the batch task
This commit is contained in:
parent
9e042bd117
commit
2af011a1c0
@ -70,7 +70,7 @@ impl Infer {
|
|||||||
tokenizer_config: HubTokenizerConfig,
|
tokenizer_config: HubTokenizerConfig,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Infer shared state
|
// Infer shared state
|
||||||
let queue = Queue::new(requires_padding, max_batch_size, 16, window_size, speculate);
|
let queue = Queue::new(requires_padding, 16, window_size, speculate);
|
||||||
let shared = Arc::new(Shared {
|
let shared = Arc::new(Shared {
|
||||||
batching_task: Notify::new(),
|
batching_task: Notify::new(),
|
||||||
});
|
});
|
||||||
@ -82,6 +82,7 @@ impl Infer {
|
|||||||
max_batch_prefill_tokens,
|
max_batch_prefill_tokens,
|
||||||
max_batch_total_tokens,
|
max_batch_total_tokens,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
queue.clone(),
|
queue.clone(),
|
||||||
shared.clone(),
|
shared.clone(),
|
||||||
generation_health,
|
generation_health,
|
||||||
@ -339,6 +340,7 @@ async fn batching_task(
|
|||||||
max_batch_prefill_tokens: u32,
|
max_batch_prefill_tokens: u32,
|
||||||
max_batch_total_tokens: u32,
|
max_batch_total_tokens: u32,
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
queue: Queue,
|
queue: Queue,
|
||||||
shared: Arc<Shared>,
|
shared: Arc<Shared>,
|
||||||
generation_health: Arc<AtomicBool>,
|
generation_health: Arc<AtomicBool>,
|
||||||
@ -352,7 +354,12 @@ async fn batching_task(
|
|||||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||||
// waiting in the queue
|
// waiting in the queue
|
||||||
while let Some((mut entries, batch, span)) = queue
|
while let Some((mut entries, batch, span)) = queue
|
||||||
.next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
|
.next_batch(
|
||||||
|
None,
|
||||||
|
max_batch_size,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
|
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
|
||||||
@ -380,10 +387,11 @@ async fn batching_task(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||||
|
let max_size = max_batch_size.map(|max_size| batch_size as usize - max_size);
|
||||||
|
|
||||||
// Try to get a new batch
|
// Try to get a new batch
|
||||||
if let Some((mut new_entries, new_batch, span)) = queue
|
if let Some((mut new_entries, new_batch, span)) = queue
|
||||||
.next_batch(min_size, max_batch_prefill_tokens, token_budget)
|
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
// Tracking metrics
|
// Tracking metrics
|
||||||
|
@ -36,7 +36,6 @@ pub(crate) struct Queue {
|
|||||||
impl Queue {
|
impl Queue {
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
requires_padding: bool,
|
requires_padding: bool,
|
||||||
max_batch_size: Option<usize>,
|
|
||||||
block_size: u32,
|
block_size: u32,
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
@ -47,7 +46,6 @@ impl Queue {
|
|||||||
// Launch background queue task
|
// Launch background queue task
|
||||||
tokio::spawn(queue_task(
|
tokio::spawn(queue_task(
|
||||||
requires_padding,
|
requires_padding,
|
||||||
max_batch_size,
|
|
||||||
block_size,
|
block_size,
|
||||||
window_size,
|
window_size,
|
||||||
speculate,
|
speculate,
|
||||||
@ -72,6 +70,7 @@ impl Queue {
|
|||||||
pub(crate) async fn next_batch(
|
pub(crate) async fn next_batch(
|
||||||
&self,
|
&self,
|
||||||
min_size: Option<usize>,
|
min_size: Option<usize>,
|
||||||
|
max_size: Option<usize>,
|
||||||
prefill_token_budget: u32,
|
prefill_token_budget: u32,
|
||||||
token_budget: u32,
|
token_budget: u32,
|
||||||
) -> Option<NextBatch> {
|
) -> Option<NextBatch> {
|
||||||
@ -82,6 +81,7 @@ impl Queue {
|
|||||||
self.queue_sender
|
self.queue_sender
|
||||||
.send(QueueCommand::NextBatch {
|
.send(QueueCommand::NextBatch {
|
||||||
min_size,
|
min_size,
|
||||||
|
max_size,
|
||||||
prefill_token_budget,
|
prefill_token_budget,
|
||||||
token_budget,
|
token_budget,
|
||||||
response_sender,
|
response_sender,
|
||||||
@ -97,7 +97,6 @@ impl Queue {
|
|||||||
// Background task responsible of the queue state
|
// Background task responsible of the queue state
|
||||||
async fn queue_task(
|
async fn queue_task(
|
||||||
requires_padding: bool,
|
requires_padding: bool,
|
||||||
max_size: Option<usize>,
|
|
||||||
block_size: u32,
|
block_size: u32,
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
@ -113,6 +112,7 @@ async fn queue_task(
|
|||||||
}
|
}
|
||||||
QueueCommand::NextBatch {
|
QueueCommand::NextBatch {
|
||||||
min_size,
|
min_size,
|
||||||
|
max_size,
|
||||||
prefill_token_budget,
|
prefill_token_budget,
|
||||||
token_budget,
|
token_budget,
|
||||||
response_sender,
|
response_sender,
|
||||||
@ -332,6 +332,7 @@ enum QueueCommand {
|
|||||||
Append(Box<Entry>, Span),
|
Append(Box<Entry>, Span),
|
||||||
NextBatch {
|
NextBatch {
|
||||||
min_size: Option<usize>,
|
min_size: Option<usize>,
|
||||||
|
max_size: Option<usize>,
|
||||||
prefill_token_budget: u32,
|
prefill_token_budget: u32,
|
||||||
token_budget: u32,
|
token_budget: u32,
|
||||||
response_sender: oneshot::Sender<Option<NextBatch>>,
|
response_sender: oneshot::Sender<Option<NextBatch>>,
|
||||||
@ -494,28 +495,28 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_append() {
|
async fn test_queue_append() {
|
||||||
let queue = Queue::new(false, None, 1, None, 0);
|
let queue = Queue::new(false, 1, None, 0);
|
||||||
let (entry, _guard) = default_entry();
|
let (entry, _guard) = default_entry();
|
||||||
queue.append(entry);
|
queue.append(entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_empty() {
|
async fn test_queue_next_batch_empty() {
|
||||||
let queue = Queue::new(false, None, 1, None, 0);
|
let queue = Queue::new(false, 1, None, 0);
|
||||||
|
|
||||||
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||||
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
|
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_min_size() {
|
async fn test_queue_next_batch_min_size() {
|
||||||
let queue = Queue::new(false, None, 1, None, 0);
|
let queue = Queue::new(false, 1, None, 0);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
queue.append(entry2);
|
queue.append(entry2);
|
||||||
|
|
||||||
let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap();
|
let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
|
||||||
assert_eq!(entries.len(), 2);
|
assert_eq!(entries.len(), 2);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert!(entries.contains_key(&1));
|
assert!(entries.contains_key(&1));
|
||||||
@ -528,11 +529,11 @@ mod tests {
|
|||||||
queue.append(entry3);
|
queue.append(entry3);
|
||||||
|
|
||||||
// Not enough requests pending
|
// Not enough requests pending
|
||||||
assert!(queue.next_batch(Some(2), 2, 2).await.is_none());
|
assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
|
||||||
// Not enough token budget
|
// Not enough token budget
|
||||||
assert!(queue.next_batch(Some(1), 0, 0).await.is_none());
|
assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
|
||||||
// Ok
|
// Ok
|
||||||
let (entries2, batch2, _) = queue.next_batch(Some(1), 2, 2).await.unwrap();
|
let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
|
||||||
assert_eq!(entries2.len(), 1);
|
assert_eq!(entries2.len(), 1);
|
||||||
assert!(entries2.contains_key(&2));
|
assert!(entries2.contains_key(&2));
|
||||||
assert!(entries2.get(&2).unwrap().batch_time.is_some());
|
assert!(entries2.get(&2).unwrap().batch_time.is_some());
|
||||||
@ -542,13 +543,13 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_max_size() {
|
async fn test_queue_next_batch_max_size() {
|
||||||
let queue = Queue::new(false, Some(1), 1, None, 0);
|
let queue = Queue::new(false, 1, None, 0);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
queue.append(entry2);
|
queue.append(entry2);
|
||||||
|
|
||||||
let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap();
|
let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
|
||||||
assert_eq!(entries.len(), 1);
|
assert_eq!(entries.len(), 1);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
||||||
@ -558,13 +559,13 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_token_budget() {
|
async fn test_queue_next_batch_token_budget() {
|
||||||
let queue = Queue::new(false, None, 1, None, 0);
|
let queue = Queue::new(false, 1, None, 0);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
queue.append(entry2);
|
queue.append(entry2);
|
||||||
|
|
||||||
let (entries, batch, _) = queue.next_batch(None, 1, 1).await.unwrap();
|
let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
|
||||||
assert_eq!(entries.len(), 1);
|
assert_eq!(entries.len(), 1);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert_eq!(batch.id, 0);
|
assert_eq!(batch.id, 0);
|
||||||
@ -573,7 +574,7 @@ mod tests {
|
|||||||
let (entry3, _guard3) = default_entry();
|
let (entry3, _guard3) = default_entry();
|
||||||
queue.append(entry3);
|
queue.append(entry3);
|
||||||
|
|
||||||
let (entries, batch, _) = queue.next_batch(None, 3, 3).await.unwrap();
|
let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
|
||||||
assert_eq!(entries.len(), 2);
|
assert_eq!(entries.len(), 2);
|
||||||
assert!(entries.contains_key(&1));
|
assert!(entries.contains_key(&1));
|
||||||
assert!(entries.contains_key(&2));
|
assert!(entries.contains_key(&2));
|
||||||
@ -583,16 +584,16 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_token_speculate() {
|
async fn test_queue_next_batch_token_speculate() {
|
||||||
let queue = Queue::new(false, None, 1, None, 2);
|
let queue = Queue::new(false, 1, None, 2);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
queue.append(entry2);
|
queue.append(entry2);
|
||||||
|
|
||||||
// Budget of 1 is not enough
|
// Budget of 1 is not enough
|
||||||
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||||
|
|
||||||
let (entries, batch, _) = queue.next_batch(None, 6, 6).await.unwrap();
|
let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
|
||||||
assert_eq!(entries.len(), 2);
|
assert_eq!(entries.len(), 2);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert!(entries.contains_key(&1));
|
assert!(entries.contains_key(&1));
|
||||||
@ -602,10 +603,10 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_dropped_receiver() {
|
async fn test_queue_next_batch_dropped_receiver() {
|
||||||
let queue = Queue::new(false, None, 1, None, 0);
|
let queue = Queue::new(false, 1, None, 0);
|
||||||
let (entry, _) = default_entry();
|
let (entry, _) = default_entry();
|
||||||
queue.append(entry);
|
queue.append(entry);
|
||||||
|
|
||||||
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user