mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
fix: batching
This commit is contained in:
parent
12590fdcce
commit
3756a5f1e2
@@ -257,6 +257,7 @@ impl Infer {
 ///
 /// Batches requests and sends them to the inference server
 #[allow(clippy::too_many_arguments)]
+#[instrument(skip_all)]
 async fn batching_task(
     mut client: ShardedClient,
     waiting_served_ratio: f32,
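
Note on the hunk above: `#[instrument(skip_all)]` comes from the `tracing` crate and wraps the function in a span without recording any of its arguments as span fields. A minimal sketch outside of TGI, assuming tracing = "0.1", tracing-subscriber = "0.3", and tokio with the "full" feature; the function and argument names are illustrative only:

use tracing::instrument;

// `skip_all` opens a span named after the function without recording any
// arguments as fields (the real batching_task takes a ShardedClient and
// queues that are not cheap or meaningful to record).
#[instrument(skip_all)]
async fn toy_batching_task(waiting_served_ratio: f32) {
    // Events emitted here are attached to the `toy_batching_task` span.
    tracing::debug!(waiting_served_ratio = f64::from(waiting_served_ratio), "First batch");
}

#[tokio::main]
async fn main() {
    // Print spans and DEBUG-level events to stdout.
    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::DEBUG)
        .init();
    toy_batching_task(1.2).await;
}
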
@@ -275,6 +276,7 @@ async fn batching_task(
         // Get the next batch from the queue
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the queue
+        tracing::debug!("First batch");
         while let Some((mut entries, batch, span)) = queue
             .next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
             .await
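
The comment in this hunk ("This batch might be smaller than the maximum batch size if there are not enough requests waiting in the queue") describes budget-limited batching. A hypothetical, self-contained sketch of that idea; `ToyQueue`, `ToyRequest`, and the token-budget rule below are illustrative assumptions, not the TGI `Queue` API:

use std::collections::VecDeque;

struct ToyRequest {
    id: u64,
    input_tokens: u32,
}

struct ToyQueue {
    entries: VecDeque<ToyRequest>,
}

impl ToyQueue {
    // Pop queued requests until the prefill-token budget would be exceeded.
    // Returns None when the queue is empty, mirroring the
    // `while let Some(...) = queue.next_batch(...)` shape in the hunk.
    fn next_batch(&mut self, max_prefill_tokens: u32) -> Option<Vec<ToyRequest>> {
        let mut batch = Vec::new();
        let mut budget = 0u32;
        while let Some(front_tokens) = self.entries.front().map(|r| r.input_tokens) {
            // Always admit at least one request, even if it alone exceeds the budget.
            if !batch.is_empty() && budget + front_tokens > max_prefill_tokens {
                break;
            }
            budget += front_tokens;
            batch.push(self.entries.pop_front().unwrap());
        }
        if batch.is_empty() { None } else { Some(batch) }
    }
}

fn main() {
    let mut queue = ToyQueue {
        entries: (0..5).map(|id| ToyRequest { id, input_tokens: 40 }).collect(),
    };
    // With a 100-token budget the batches come out as 2, 2, and 1 requests.
    while let Some(batch) = queue.next_batch(100) {
        let ids: Vec<u64> = batch.iter().map(|r| r.id).collect();
        println!("batch of {}: {:?}", batch.len(), ids);
    }
}
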
@@ -378,6 +380,8 @@ async fn prefill(
     let batch_id = batch.id;
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
 
+    tracing::debug!("Prefill");
+
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
             // Update health
@@ -415,6 +419,8 @@ async fn decode(
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
 
+    tracing::debug!("Decode");
+
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
             // Update health
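
Both the prefill and decode hunks pair a `metrics::increment_counter!` call with the new `tracing::debug!` event. A minimal sketch of that pair, assuming a `metrics` crate version that still ships the `increment_counter!("name", "key" => "value")` macro form used in the diff (newer releases moved to `counter!(...).increment(1)`), plus `tracing` and `tracing-subscriber`:

// Without an installed metrics recorder (for example a Prometheus exporter)
// the counter call is a no-op, so this compiles and runs without extra setup.
fn record_prefill() {
    metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
    tracing::debug!("Prefill");
}

fn main() {
    // Emit DEBUG-level events to stdout so the debug! line is visible.
    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::DEBUG)
        .init();
    record_prefill();
}
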
@@ -513,6 +519,7 @@ fn send_responses(
 ) -> Result<bool, Box<SendTimeoutError<Result<InferStreamResponse, InferError>>>> {
     // Return directly if the channel is disconnected
     if entry.response_tx.is_disconnected() {
+        metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
         return Ok(true);
     }
 
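
The send_responses hunk counts requests whose client already went away before doing any more work. A sketch of that early-return pattern, assuming the `flume` channel crate (whose `Sender::is_disconnected()` matches the `entry.response_tx.is_disconnected()` call, and whose `send_timeout` produces the `SendTimeoutError` in the signature) and a `metrics` version exposing `increment_counter!`; `send_if_listening` is a made-up helper:

fn send_if_listening(tx: &flume::Sender<String>, msg: String) -> bool {
    // If the receiver was dropped (the HTTP client went away), skip the send,
    // count the dropped request, and tell the caller to stop streaming.
    if tx.is_disconnected() {
        metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
        return true;
    }
    tx.send(msg).is_err()
}

fn main() {
    let (tx, rx) = flume::unbounded::<String>();
    // Receiver still alive: the send goes through.
    assert!(!send_if_listening(&tx, "token".into()));
    drop(rx);
    // Receiver gone: the helper reports "stop" without sending.
    assert!(send_if_listening(&tx, "token".into()));
}
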
@@ -82,7 +82,9 @@ impl Queue {
             .unwrap();
         // Await on response channel
         // Unwrap is safe here
-        response_receiver.await.unwrap()
+        let response = response_receiver.await.unwrap();
+        tracing::debug!("Next batch: {}", response.is_some());
+        response
     }
 }
 
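
The queue hunk awaits a reply on a response channel and now logs whether a batch came back. A sketch of that command/oneshot-reply pattern using `tokio::sync::{mpsc, oneshot}`; the background `queue_task` and the `Command` enum are assumptions for illustration, not the actual TGI queue internals:

// Assumes tokio = { version = "1", features = ["full"] } and tracing-subscriber.
use tokio::sync::{mpsc, oneshot};

// A single command kind: "give me the next batch, reply on this oneshot".
enum Command {
    NextBatch {
        response_tx: oneshot::Sender<Option<Vec<u64>>>,
    },
}

// Background task owning the queue state; it always replies before dropping
// the oneshot sender, which is what makes the caller's unwrap safe.
async fn queue_task(mut rx: mpsc::Receiver<Command>) {
    let mut pending: Vec<u64> = vec![1, 2, 3];
    while let Some(Command::NextBatch { response_tx }) = rx.recv().await {
        let batch = if pending.is_empty() {
            None
        } else {
            Some(std::mem::take(&mut pending))
        };
        // The caller may have given up; a failed reply is fine to ignore.
        let _ = response_tx.send(batch);
    }
}

#[tokio::main]
async fn main() {
    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::DEBUG)
        .init();
    let (tx, rx) = mpsc::channel(8);
    tokio::spawn(queue_task(rx));

    let (response_tx, response_rx) = oneshot::channel();
    tx.send(Command::NextBatch { response_tx }).await.unwrap();
    // Await on the response channel; unwrap is safe because queue_task replies
    // before dropping the sender, the same invariant as the original comment.
    let response = response_rx.await.unwrap();
    tracing::debug!("Next batch: {}", response.is_some());
}
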