fix: batching

OlivierDehaene 2023-10-20 11:53:23 +02:00
parent 12590fdcce
commit 3756a5f1e2
2 changed files with 10 additions and 1 deletion

@@ -257,6 +257,7 @@ impl Infer {
///
/// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)]
#[instrument(skip_all)]
async fn batching_task(
mut client: ShardedClient,
waiting_served_ratio: f32,
@@ -275,6 +276,7 @@ async fn batching_task(
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
tracing::debug!("First batch");
while let Some((mut entries, batch, span)) = queue
.next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
.await
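
For context on the hunk above: `batching_task` keeps asking the queue for the next batch under the prefill and total token budgets, and the new debug trace marks the point where the first batch of an iteration is picked up. Below is a minimal, synchronous sketch of that queue-draining shape; `SimpleQueue`, `Request`, and `Batch` are illustrative stand-ins, not the router's real types.

```rust
use std::collections::VecDeque;

struct Request {
    prefill_tokens: u32,
}

struct Batch {
    requests: Vec<Request>,
}

struct SimpleQueue {
    pending: VecDeque<Request>,
}

impl SimpleQueue {
    // Greedily take requests from the front until the prefill-token budget runs out.
    fn next_batch(&mut self, max_batch_prefill_tokens: u32) -> Option<Batch> {
        let mut requests = Vec::new();
        let mut budget = max_batch_prefill_tokens;
        while let Some(req) = self.pending.front() {
            if req.prefill_tokens > budget {
                break;
            }
            budget -= req.prefill_tokens;
            requests.push(self.pending.pop_front().unwrap());
        }
        if requests.is_empty() {
            None
        } else {
            Some(Batch { requests })
        }
    }
}

fn main() {
    let mut queue = SimpleQueue {
        pending: VecDeque::from([
            Request { prefill_tokens: 64 },
            Request { prefill_tokens: 128 },
            Request { prefill_tokens: 512 },
        ]),
    };

    // Mirrors the new debug trace: report whenever a batch is picked up from the queue.
    while let Some(batch) = queue.next_batch(256) {
        println!("First batch: {} request(s)", batch.requests.len());
    }
}
```
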
@@ -378,6 +380,8 @@ async fn prefill(
let batch_id = batch.id;
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
tracing::debug!("Prefill");
match client.prefill(batch).await {
Ok((generations, next_batch)) => {
// Update health
@@ -415,6 +419,8 @@ async fn decode(
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
tracing::debug!("Decode");
match client.decode(batches).await {
Ok((generations, next_batch)) => {
// Update health
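
`prefill` and `decode` get the same treatment: a counter labelled with the method name, followed by a new debug trace right before the call into the sharded client. A minimal sketch of that pattern, assuming the `metrics` crate's `increment_counter!` macro (as used above), the `tracing` and `tracing-subscriber` crates, and a placeholder function standing in for the real gRPC client call:

```rust
use tracing::{debug, Level};

// Placeholder for the sharded client call; in the real router this is an async
// gRPC request to the inference shards.
fn fake_prefill() -> Result<usize, String> {
    Ok(1)
}

fn prefill_step() {
    // Count every prefill attempt, labelled by method, then emit a debug trace
    // right before calling the client so batch progress is visible at debug level.
    metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
    debug!("Prefill");

    match fake_prefill() {
        Ok(generations) => debug!("Prefill returned {} generation(s)", generations),
        Err(err) => debug!("Prefill failed: {}", err),
    }
}

fn main() {
    // A subscriber is required for the debug traces to actually be printed.
    tracing_subscriber::fmt().with_max_level(Level::DEBUG).init();
    prefill_step();
}
```

The `decode` path is symmetric, with `"method" => "decode"` and a `"Decode"` trace.
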
@@ -513,6 +519,7 @@ fn send_responses(
) -> Result<bool, Box<SendTimeoutError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_disconnected() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
return Ok(true);
}
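
The early return above now also records the dropped request in the failure counter. A small self-contained sketch of the same check, assuming a `flume` channel for the per-request response stream; the `Entry` struct and the `String` message type are simplified stand-ins for the router's real types:

```rust
// Hypothetical, simplified stand-in for the router's Entry type.
struct Entry {
    response_tx: flume::Sender<String>,
}

fn send_responses(entry: &Entry, message: String) -> Result<bool, flume::SendError<String>> {
    // If the client went away, the receiving side of the channel is gone:
    // record the dropped request and stop streaming for this entry.
    if entry.response_tx.is_disconnected() {
        metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
        return Ok(true); // `true` signals "this entry is finished"
    }
    entry.response_tx.send(message)?;
    Ok(false)
}

fn main() {
    let (tx, rx) = flume::unbounded::<String>();
    let entry = Entry { response_tx: tx };

    // Simulate a client disconnect by dropping the receiver.
    drop(rx);
    match send_responses(&entry, "token".to_string()) {
        Ok(true) => println!("request dropped: stop streaming for this entry"),
        Ok(false) => println!("response sent"),
        Err(_) => println!("send failed"),
    }
}
```
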

@@ -82,7 +82,9 @@ impl Queue {
.unwrap();
// Await on response channel
// Unwrap is safe here
response_receiver.await.unwrap()
let response = response_receiver.await.unwrap();
tracing::debug!("Next batch: {}", response.is_some());
response
}
}
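
The `Queue` handle talks to a background task over a channel and receives its answer on a oneshot sender, which is why the awaited `response_receiver` can be unwrapped as long as that task is alive; the new debug trace only reports whether a batch came back. A minimal sketch of that request/response pattern, assuming Tokio channels (all names here are illustrative, not TGI's actual types):

```rust
use tokio::sync::{mpsc, oneshot};

// Toy stand-in for the real batch type.
struct Batch {
    size: usize,
}

// Command sent from the `Queue` handle to the background queue task.
enum QueueCommand {
    NextBatch {
        response_sender: oneshot::Sender<Option<Batch>>,
    },
}

struct Queue {
    command_sender: mpsc::UnboundedSender<QueueCommand>,
}

impl Queue {
    fn new() -> Self {
        let (command_sender, mut command_receiver) = mpsc::unbounded_channel();

        // Background task owning the queue state; it answers every command it receives.
        tokio::spawn(async move {
            let mut pending = 3usize;
            while let Some(QueueCommand::NextBatch { response_sender }) =
                command_receiver.recv().await
            {
                let batch = if pending > 0 {
                    pending -= 1;
                    Some(Batch { size: 1 })
                } else {
                    None
                };
                // Ignore the error if the caller stopped waiting for the answer.
                let _ = response_sender.send(batch);
            }
        });

        Self { command_sender }
    }

    async fn next_batch(&self) -> Option<Batch> {
        let (response_sender, response_receiver) = oneshot::channel();
        // The real router unwraps this send; returning None keeps the sketch
        // panic-free if the background task is gone.
        if self
            .command_sender
            .send(QueueCommand::NextBatch { response_sender })
            .is_err()
        {
            return None;
        }
        // The background task always replies while it is running, which is why
        // awaiting the oneshot receiver is safe to unwrap in the real code.
        let response = response_receiver.await.unwrap();
        tracing::debug!("Next batch: {}", response.is_some());
        response
    }
}

#[tokio::main]
async fn main() {
    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::DEBUG)
        .init();

    let queue = Queue::new();
    while let Some(batch) = queue.next_batch().await {
        println!("got a batch of size {}", batch.size);
    }
}
```
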