From b4187d602217aafb7fe345d22ab843d48acf3393 Mon Sep 17 00:00:00 2001
From: Corentin REGAL
Date: Fri, 17 Jan 2025 15:48:02 +0100
Subject: [PATCH] Add tgi_batch_current_size and tgi_queue_size as response
 headers

---
 backends/v2/src/backend.rs |  4 ++++
 backends/v2/src/queue.rs   |  4 ++++
 backends/v3/src/backend.rs |  4 ++++
 backends/v3/src/queue.rs   |  4 ++++
 router/src/server.rs       | 21 +++++++++++++++++++++
 5 files changed, 37 insertions(+)

diff --git a/backends/v2/src/backend.rs b/backends/v2/src/backend.rs
index cfe87f98f..69170a80f 100644
--- a/backends/v2/src/backend.rs
+++ b/backends/v2/src/backend.rs
@@ -3,8 +3,10 @@ use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, Sharded
 use crate::queue::{Entry, Queue};
 use async_trait::async_trait;
 use nohash_hasher::IntMap;
+use std::sync::atomic::Ordering::SeqCst;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::server::BATCH_CURRENT_SIZE;
 use text_generation_router::validation::ValidGenerateRequest;
 use text_generation_router::{FinishReason, PrefillToken, Token};
 use tokio::sync::mpsc::error::SendError;
@@ -155,6 +157,7 @@ pub(crate) async fn batching_task(
                 let batch_max_tokens = batch.max_tokens;
                 let mut batches = vec![batch];
                 metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
+                BATCH_CURRENT_SIZE.store(batch_size, SeqCst);
                 metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
 
                 let min_size = if waiting_tokens >= max_waiting_tokens {
@@ -227,6 +230,7 @@ pub(crate) async fn batching_task(
                 waiting_tokens += 1;
             }
             metrics::gauge!("tgi_batch_current_size").set(0.0);
+            BATCH_CURRENT_SIZE.store(0, SeqCst);
             metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
         }
     }
diff --git a/backends/v2/src/queue.rs b/backends/v2/src/queue.rs
index c9a9335dd..6603d9c1e 100644
--- a/backends/v2/src/queue.rs
+++ b/backends/v2/src/queue.rs
@@ -4,8 +4,10 @@ use crate::client::{
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
 use std::collections::VecDeque;
+use std::sync::atomic::Ordering::SeqCst;
 use text_generation_router::infer::InferError;
 use text_generation_router::infer::InferStreamResponse;
+use text_generation_router::server::QUEUE_SIZE;
 use text_generation_router::validation::{
     ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
 };
@@ -112,6 +114,7 @@ async fn queue_task(
             QueueCommand::Append(entry, span) => {
                 span.in_scope(|| state.append(*entry));
                 metrics::gauge!("tgi_queue_size").increment(1.0);
+                QUEUE_SIZE.fetch_add(1, SeqCst);
             }
             QueueCommand::NextBatch {
                 min_size,
@@ -125,6 +128,7 @@ async fn queue_task(
                     state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
                 response_sender.send(next_batch).unwrap();
                 metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
+                QUEUE_SIZE.store(state.entries.len() as u32, SeqCst);
             }),
         }
     }
diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs
index 736301b33..b5e372197 100644
--- a/backends/v3/src/backend.rs
+++ b/backends/v3/src/backend.rs
@@ -5,8 +5,10 @@ use crate::client::{
 use crate::queue::{Entry, Queue};
 use async_trait::async_trait;
 use nohash_hasher::IntMap;
+use std::sync::atomic::Ordering::SeqCst;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::server::BATCH_CURRENT_SIZE;
 use text_generation_router::validation::ValidGenerateRequest;
 use text_generation_router::{FinishReason, PrefillToken, Token};
 use tokio::sync::mpsc::error::SendError;
@@ -164,6 +166,7 @@ pub(crate) async fn batching_task(
                 let current_tokens = batch.current_tokens;
                 let mut batches = vec![batch];
                 metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
+                BATCH_CURRENT_SIZE.store(batch_size, SeqCst);
                 metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
 
                 let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
@@ -284,6 +287,7 @@ pub(crate) async fn batching_task(
                 waiting_tokens += 1;
             }
             metrics::gauge!("tgi_batch_current_size").set(0.0);
+            BATCH_CURRENT_SIZE.store(0, SeqCst);
             metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
         }
     }
diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index 249eebf76..fecc8501e 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -6,8 +6,10 @@ use crate::client::{
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::max;
 use std::collections::VecDeque;
+use std::sync::atomic::Ordering::SeqCst;
 use text_generation_router::infer::InferError;
 use text_generation_router::infer::InferStreamResponse;
+use text_generation_router::server::QUEUE_SIZE;
 use text_generation_router::validation::{
     Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
     ValidStoppingParameters,
@@ -140,6 +142,7 @@ async fn queue_task(
             QueueCommand::Append(entry, span) => {
                 span.in_scope(|| state.append(*entry));
                 metrics::gauge!("tgi_queue_size").increment(1.0);
+                QUEUE_SIZE.fetch_add(1, SeqCst);
             }
             QueueCommand::NextBatch {
                 min_size,
@@ -155,6 +158,7 @@ async fn queue_task(
                     .await;
                 response_sender.send(next_batch).unwrap();
                 metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
+                QUEUE_SIZE.store(state.entries.len() as u32, SeqCst);
             }
         }
     }
diff --git a/router/src/server.rs b/router/src/server.rs
index aef0f8120..576b87567 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -54,6 +54,8 @@ use std::fs::File;
 use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::{Path, PathBuf};
+use std::sync::atomic::AtomicU32;
+use std::sync::atomic::Ordering::SeqCst;
 use thiserror::Error;
 use tokio::select;
 use tokio::signal;
@@ -400,6 +402,14 @@ pub(crate) async fn generate_internal(
         "x-generated-tokens",
         response.generated_text.generated_tokens.into(),
     );
+    headers.insert(
+        "x-batch-current-size",
+        BATCH_CURRENT_SIZE.load(SeqCst).to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-queue-size",
+        QUEUE_SIZE.load(SeqCst).to_string().parse().unwrap(),
+    );
 
     // Metrics
     metrics::counter!("tgi_request_success").increment(1);
@@ -490,6 +500,9 @@ async fn generate_stream(
     (headers, sse)
 }
 
+pub static BATCH_CURRENT_SIZE: AtomicU32 = AtomicU32::new(0);
+pub static QUEUE_SIZE: AtomicU32 = AtomicU32::new(0);
+
 async fn generate_stream_internal(
     infer: Infer,
     ComputeType(compute_type): ComputeType,
@@ -513,6 +526,14 @@ async fn generate_stream_internal(
         compute_characters.to_string().parse().unwrap(),
    );
     headers.insert("X-Accel-Buffering", "no".parse().unwrap());
+    headers.insert(
+        "x-batch-current-size",
+        BATCH_CURRENT_SIZE.load(SeqCst).to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-queue-size",
+        QUEUE_SIZE.load(SeqCst).to_string().parse().unwrap(),
+    );
 
     let stream = async_stream::stream! {
         // Inference
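
For reference, below is a minimal client-side sketch (not part of the patch) showing how the new x-batch-current-size and x-queue-size response headers could be read after a /generate call. It assumes a TGI router listening on localhost:3000 and uses the reqwest (with the json feature), tokio, and serde_json crates; the URL and request body are illustrative only.

use reqwest::Client;
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = Client::new();

    // Send a regular generate request to an assumed local TGI instance.
    let resp = client
        .post("http://localhost:3000/generate")
        .json(&json!({
            "inputs": "What is Deep Learning?",
            "parameters": { "max_new_tokens": 20 }
        }))
        .send()
        .await?;

    // Read the router-side gauges exposed by this patch as response headers.
    let batch_size = resp
        .headers()
        .get("x-batch-current-size")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("unknown");
    let queue_size = resp
        .headers()
        .get("x-queue-size")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("unknown");

    println!("batch current size: {batch_size}, queue size: {queue_size}");
    Ok(())
}

Because the headers are rendered from process-wide atomics (BATCH_CURRENT_SIZE and QUEUE_SIZE), the values describe router-wide state at the moment the response is built, not anything specific to the individual request.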