Prefer prefill instead of decode when max_waiting_tokens==0 (#18)

mrs303 authored 2024-01-19 15:25:40 +01:00 (committed by GitHub)
parent 60f63262db
commit da0f874d49


@@ -310,42 +310,54 @@ async fn batching_task(
                 Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
             };
-            let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
+            let mut token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
-            // Try to get a new batch
-            if let Some((mut new_entries, new_batch, span)) = queue
-                .next_batch(min_size, max_batch_prefill_tokens, token_budget)
-                .await
-            {
-                // Tracking metrics
-                if min_size.is_some() {
-                    metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
+            loop {
+                // Try to get a new batch
+                if let Some((mut new_entries, new_batch, span)) = queue
+                    .next_batch(min_size, max_batch_prefill_tokens, token_budget)
+                    .await
+                {
+                    // Tracking metrics
+                    if min_size.is_some() {
+                        metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
+                    } else {
+                        metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
+                    }
+                    entries.iter_mut().for_each(|(_, entry)| {
+                        // Create a new span to add the info that this entry is waiting
+                        // because a new batch is being computed
+                        let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
+                        // Add relationships
+                        span.follows_from(&entry_waiting_span);
+                        entry_waiting_span.follows_from(&span);
+                        // Update entry
+                        entry.temp_span = Some(entry_waiting_span);
+                    });
+                    // Generate one token for this new batch to have the attention past in cache
+                    let new_cached_batch =
+                        prefill(&mut client, new_batch, &mut new_entries, &generation_health)
+                            .instrument(span)
+                            .await;
+                    // Reset waiting counter
+                    waiting_tokens = 1;
+                    // Extend current batch with the new batch
+                    if let Some(new_cached_batch) = new_cached_batch {
+                        token_budget = token_budget.saturating_sub(new_cached_batch.max_tokens);
+                        entries.extend(new_entries);
+                        batches.push(new_cached_batch);
+                    }
                 } else {
-                    metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
+                    // Break as there is no batch
+                    break;
                 }
-                entries.iter_mut().for_each(|(_, entry)| {
-                    // Create a new span to add the info that this entry is waiting
-                    // because a new batch is being computed
-                    let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
-                    // Add relationships
-                    span.follows_from(&entry_waiting_span);
-                    entry_waiting_span.follows_from(&span);
-                    // Update entry
-                    entry.temp_span = Some(entry_waiting_span);
-                });
-                // Generate one token for this new batch to have the attention past in cache
-                let new_cached_batch =
-                    prefill(&mut client, new_batch, &mut new_entries, &generation_health)
-                        .instrument(span)
-                        .await;
-                // Reset waiting counter
-                waiting_tokens = 1;
-                // Extend current batch with the new batch
-                if let Some(new_cached_batch) = new_cached_batch {
-                    entries.extend(new_entries);
-                    batches.push(new_cached_batch);
+                // Loop again in case of max_waiting_tokens == 0
+                // to prefer doing next prefill. Break otherwise
+                if max_waiting_tokens != 0 {
+                    break;
                 }
             }
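
In outline, the change wraps batch acquisition in a loop: when max_waiting_tokens == 0, the batching task keeps building and prefilling new batches until the queue or the token budget runs out, instead of prefilling once and going back to decode. The self-contained Rust sketch below illustrates that control flow only; CachedBatch, try_next_batch, and prefill_one are hypothetical stand-ins for the real router types and calls, not the actual API.

// Standalone sketch (not the router code): hypothetical helpers illustrate how
// the loop prefers prefill over decode when max_waiting_tokens == 0.
#[derive(Debug)]
struct CachedBatch {
    max_tokens: u32,
}

// Stand-in for `queue.next_batch(...)`: hand out a fixed-size batch while the
// remaining token budget can still fit one.
fn try_next_batch(token_budget: u32) -> Option<u32> {
    const BATCH_TOKENS: u32 = 256;
    (token_budget >= BATCH_TOKENS).then_some(BATCH_TOKENS)
}

// Stand-in for `prefill(...)`: pretend the prefill succeeded and report how
// many tokens the cached batch now occupies.
fn prefill_one(batch_tokens: u32) -> Option<CachedBatch> {
    Some(CachedBatch { max_tokens: batch_tokens })
}

fn main() {
    let max_waiting_tokens: u32 = 0; // 0 => keep prefilling instead of decoding
    let mut token_budget: u32 = 1000;
    let mut batches: Vec<CachedBatch> = Vec::new();

    loop {
        if let Some(batch_tokens) = try_next_batch(token_budget) {
            if let Some(cached) = prefill_one(batch_tokens) {
                // Charge the new batch against the remaining budget, as the
                // patch does with `new_cached_batch.max_tokens`.
                token_budget = token_budget.saturating_sub(cached.max_tokens);
                batches.push(cached);
            }
        } else {
            // No batch could be built: nothing left to prefill.
            break;
        }

        // With max_waiting_tokens == 0 we loop and prefill again; otherwise we
        // break after a single prefill and return to decoding.
        if max_waiting_tokens != 0 {
            break;
        }
    }

    println!("prefilled {} batches, budget left: {}", batches.len(), token_budget);
}

With the numbers above, the loop prefills three 256-token batches and stops once the remaining 232-token budget can no longer fit another one; with a non-zero max_waiting_tokens it would break after the first prefill.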