Add tgi_batch_current_size and tgi_queue_size as response headers

Corentin REGAL 2025-01-17 15:48:02 +01:00
parent c20025dbf7
commit b4187d6022
5 changed files with 37 additions and 0 deletions
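In effect, the change mirrors the tgi_batch_current_size and tgi_queue_size Prometheus gauges into process-wide atomics and returns their current values on every response as the x-batch-current-size and x-queue-size headers. A minimal client-side sketch of reading them, assuming a local TGI instance on http://localhost:8080 and the reqwest (blocking, "json" feature) and serde_json crates:

use reqwest::blocking::Client;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = Client::new();
    let resp = client
        .post("http://localhost:8080/generate")
        .json(&serde_json::json!({
            "inputs": "Hello",
            "parameters": { "max_new_tokens": 16 }
        }))
        .send()?;

    // Both headers carry plain integers rendered as strings.
    for name in ["x-batch-current-size", "x-queue-size"] {
        if let Some(value) = resp.headers().get(name) {
            println!("{name}: {}", value.to_str()?);
        }
    }
    Ok(())
}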

View File

@ -3,8 +3,10 @@ use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, Sharded
use crate::queue::{Entry, Queue};
use async_trait::async_trait;
use nohash_hasher::IntMap;
use std::sync::atomic::Ordering::SeqCst;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::server::BATCH_CURRENT_SIZE;
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError;
@ -155,6 +157,7 @@ pub(crate) async fn batching_task(
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
BATCH_CURRENT_SIZE.store(batch_size, SeqCst);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let min_size = if waiting_tokens >= max_waiting_tokens {
@ -227,6 +230,7 @@ pub(crate) async fn batching_task(
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size").set(0.0);
BATCH_CURRENT_SIZE.store(0, SeqCst);
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
}
}
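The pattern added in this hunk, and repeated in the other backend and queue files below, is to mirror each gauge update into a plain static AtomicU32 so the HTTP layer can read the value without going through the metrics registry. A standalone sketch of the idea, with illustrative names rather than the TGI ones:

use std::sync::atomic::{AtomicU32, Ordering::SeqCst};

static CURRENT_SIZE: AtomicU32 = AtomicU32::new(0);

// Writer side: the batching loop records the live batch size.
fn on_new_batch(batch_size: u32) {
    CURRENT_SIZE.store(batch_size, SeqCst);
}

// Reader side: a request handler snapshots the value for a response header.
fn current_size_header() -> String {
    CURRENT_SIZE.load(SeqCst).to_string()
}

fn main() {
    on_new_batch(4);
    println!("x-batch-current-size: {}", current_size_header());
}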

View File

@ -4,8 +4,10 @@ use crate::client::{
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use std::collections::VecDeque;
use std::sync::atomic::Ordering::SeqCst;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::server::QUEUE_SIZE;
use text_generation_router::validation::{
ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
};
@ -112,6 +114,7 @@ async fn queue_task(
QueueCommand::Append(entry, span) => {
span.in_scope(|| state.append(*entry));
metrics::gauge!("tgi_queue_size").increment(1.0);
QUEUE_SIZE.fetch_add(1, SeqCst);
}
QueueCommand::NextBatch {
min_size,
@ -125,6 +128,7 @@ async fn queue_task(
state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
QUEUE_SIZE.store(state.entries.len() as u32, SeqCst);
}),
}
}
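Note that the queue counter is updated in two ways: a relative increment when a request is appended, and an absolute store of the remaining queue length once a batch has been taken, which keeps the atomic from drifting away from the gauge. A small sketch of the same two operations, with hypothetical names:

use std::collections::VecDeque;
use std::sync::atomic::{AtomicU32, Ordering::SeqCst};

static QUEUE_SIZE: AtomicU32 = AtomicU32::new(0);

fn main() {
    let mut entries: VecDeque<&str> = VecDeque::new();

    // Append path: exactly one new entry, so an increment is enough.
    entries.push_back("request-1");
    QUEUE_SIZE.fetch_add(1, SeqCst);

    // Batch path: an unknown number of entries was drained, so store
    // the exact remaining length instead of decrementing.
    entries.clear();
    QUEUE_SIZE.store(entries.len() as u32, SeqCst);

    println!("x-queue-size: {}", QUEUE_SIZE.load(SeqCst));
}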

View File

@ -5,8 +5,10 @@ use crate::client::{
use crate::queue::{Entry, Queue};
use async_trait::async_trait;
use nohash_hasher::IntMap;
use std::sync::atomic::Ordering::SeqCst;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::server::BATCH_CURRENT_SIZE;
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError;
@ -164,6 +166,7 @@ pub(crate) async fn batching_task(
let current_tokens = batch.current_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
BATCH_CURRENT_SIZE.store(batch_size, SeqCst);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
@ -284,6 +287,7 @@ pub(crate) async fn batching_task(
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size").set(0.0);
BATCH_CURRENT_SIZE.store(0, SeqCst);
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
}
}

View File

@ -6,8 +6,10 @@ use crate::client::{
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::max;
use std::collections::VecDeque;
use std::sync::atomic::Ordering::SeqCst;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::server::QUEUE_SIZE;
use text_generation_router::validation::{
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
ValidStoppingParameters,
@ -140,6 +142,7 @@ async fn queue_task(
QueueCommand::Append(entry, span) => {
span.in_scope(|| state.append(*entry));
metrics::gauge!("tgi_queue_size").increment(1.0);
QUEUE_SIZE.fetch_add(1, SeqCst);
}
QueueCommand::NextBatch {
min_size,
@ -155,6 +158,7 @@ async fn queue_task(
.await;
response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
QUEUE_SIZE.store(state.entries.len() as u32, SeqCst);
}
}
}

View File

@ -54,6 +54,8 @@ use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering::SeqCst;
use thiserror::Error;
use tokio::select;
use tokio::signal;
@ -400,6 +402,14 @@ pub(crate) async fn generate_internal(
"x-generated-tokens",
response.generated_text.generated_tokens.into(),
);
headers.insert(
"x-batch-current-size",
BATCH_CURRENT_SIZE.load(SeqCst).to_string().parse().unwrap(),
);
headers.insert(
"x-queue-size",
QUEUE_SIZE.load(SeqCst).to_string().parse().unwrap(),
);
// Metrics
metrics::counter!("tgi_request_success").increment(1);
@ -490,6 +500,9 @@ async fn generate_stream(
(headers, sse)
}
pub static BATCH_CURRENT_SIZE: AtomicU32 = AtomicU32::new(0);
pub static QUEUE_SIZE: AtomicU32 = AtomicU32::new(0);
async fn generate_stream_internal(
infer: Infer,
ComputeType(compute_type): ComputeType,
@ -513,6 +526,14 @@ async fn generate_stream_internal(
compute_characters.to_string().parse().unwrap(),
);
headers.insert("X-Accel-Buffering", "no".parse().unwrap());
headers.insert(
"x-batch-current-size",
BATCH_CURRENT_SIZE.load(SeqCst).to_string().parse().unwrap(),
);
headers.insert(
"x-queue-size",
QUEUE_SIZE.load(SeqCst).to_string().parse().unwrap(),
);
let stream = async_stream::stream! {
// Inference
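As a side note on the header construction in this file: the to_string().parse().unwrap() round-trip matches the style of the surrounding code, but the http crate (which axum re-exports) also implements From<u32> for HeaderValue, so the atomic's value could be attached directly. A minimal sketch, assuming the http crate:

use http::{HeaderMap, HeaderValue};
use std::sync::atomic::{AtomicU32, Ordering::SeqCst};

static QUEUE_SIZE: AtomicU32 = AtomicU32::new(0);

fn main() {
    let mut headers = HeaderMap::new();
    // HeaderValue implements From<u32>, so no string round-trip is needed.
    headers.insert("x-queue-size", HeaderValue::from(QUEUE_SIZE.load(SeqCst)));
    println!("{:?}", headers.get("x-queue-size"));
}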