This commit is contained in:
OlivierDehaene 2023-10-20 17:53:27 +02:00
parent f40f02fc25
commit 3f1cc9bad7

View File

@ -411,18 +411,22 @@ async fn decode(
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
tracing::info!("Decode");
match client.decode(batches).await {
Ok((generations, next_batch)) => {
// Update health
let _ = generation_health.send(true);
// Send generated tokens and filter stopped entries
tracing::info!("filter send");
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
tracing::info!("filter batch");
let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
tracing::info!("Decode ended");
next_batch
}
// If we have an error, we discard the whole batch
@ -462,11 +466,13 @@ async fn filter_batch(
// Next batch is now empty
// Clear it from the Python shards cache
// We unwrap here as we need to panic since we cannot recover if this method fails
tracing::info!("clear cache");
client.clear_cache(Some(id)).await.unwrap();
None
} else {
// Filter Python shard cache
// We unwrap here as we need to panic since we cannot recover if this method fails
tracing::info!("filter batch call");
client.filter_batch(id, batch.request_ids).await.unwrap()
}
}