From bb8f59632f44e2992fc5a2cfa7fa809c4a14a874 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Thu, 27 Feb 2025 14:32:51 +0100
Subject: [PATCH 1/5] feat(metrics): exposes queue size in tokens along with
 individual request count

---
 backends/v3/src/queue.rs | 11 ++++++++++-
 router/src/server.rs     |  6 ++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index 249eebf7..c0ac3e87 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -138,8 +138,9 @@ async fn queue_task(
     while let Some(cmd) = receiver.recv().await {
         match cmd {
             QueueCommand::Append(entry, span) => {
-                span.in_scope(|| state.append(*entry));
                 metrics::gauge!("tgi_queue_size").increment(1.0);
+                metrics::gauge!("tgi_queue_size_tokens").increment(entry.request.input_length);
+                span.in_scope(|| state.append(*entry));
             }
             QueueCommand::NextBatch {
                 min_size,
@@ -154,7 +155,15 @@ async fn queue_task(
                     .instrument(span)
                     .await;
                 response_sender.send(next_batch).unwrap();
+
                 metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
+                metrics::gauge!("tgi_queue_size_tokens").set(
+                    state
+                        .entries
+                        .iter()
+                        .map(|(_, e)| e.request.input_length as f64)
+                        .sum::<f64>(),
+                );
             }
         }
     }
diff --git a/router/src/server.rs b/router/src/server.rs
index e9aa4612..29392771 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -67,6 +67,7 @@ use tracing::{info_span, instrument, Instrument};
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;
 
+
 fn encoding_to_tokens(encoding: &tokenizers::Encoding, input: &str) -> Vec<SimpleToken> {
     let offsets = encoding.get_offsets();
     let input_ids = encoding.get_ids();
@@ -2171,6 +2172,11 @@ async fn start(
         "Current batch size"
     );
     metrics::describe_gauge!("tgi_queue_size", metrics::Unit::Count, "Current queue size");
+    metrics::describe_gauge!(
+        "tgi_queue_size_tokens",
+        metrics::Unit::Count,
+        "Current queue size in number of tokens"
+    );
     metrics::describe_gauge!(
         "tgi_batch_current_max_tokens",
         metrics::Unit::Count,
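
A note on the gauge discipline in this first patch: appends update both gauges incrementally, but batch formation drains an unpredictable subset of entries, so the gauges are re-derived from whatever remains in the queue rather than decremented. A library-style sketch of that increment-then-recompute pattern, using the same `metrics` macros as the patch (the `Queue` type and field layout here are illustrative, not the codebase's):

    use std::collections::VecDeque;

    // Illustrative queue: entries are (request id, input length in tokens).
    struct Queue {
        entries: VecDeque<(u64, u32)>,
    }

    impl Queue {
        // Append: O(1) gauge updates, mirroring QueueCommand::Append.
        fn append(&mut self, id: u64, input_length: u32) {
            self.entries.push_back((id, input_length));
            metrics::gauge!("tgi_queue_size").increment(1.0);
            metrics::gauge!("tgi_queue_size_tokens").increment(input_length as f64);
        }

        // After a batch is carved out, an arbitrary subset of entries has
        // left the queue, so both gauges are recomputed from the remaining
        // entries instead of decremented, mirroring QueueCommand::NextBatch.
        fn sync_gauges(&self) {
            metrics::gauge!("tgi_queue_size").set(self.entries.len() as f64);
            metrics::gauge!("tgi_queue_size_tokens")
                .set(self.entries.iter().map(|(_, len)| *len as f64).sum::<f64>());
        }
    }

Recomputing after dequeue avoids drift: an incremental decrement would have to know exactly which entries the batching pass consumed.
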
"0.21.0" @@ -51,7 +51,7 @@ tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } init-tracing-opentelemetry = { version = "0.14.1", features = [ - "opentelemetry-otlp", + "opentelemetry-otlp", ] } minijinja = { workspace = true } minijinja-contrib = { workspace = true } diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 98e8d76f..6b3993bb 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -6,16 +6,23 @@ use crate::queue::{Entry, Queue}; use async_trait::async_trait; use nohash_hasher::IntMap; use std::sync::Arc; -use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; +use text_generation_router::infer::{ + Backend, EngineState, GeneratedText, InferError, InferStreamResponse, +}; use text_generation_router::validation::ValidGenerateRequest; use text_generation_router::{FinishReason, PrefillToken, Token}; +use tokio::sync::broadcast::{channel, Receiver, Sender}; use tokio::sync::mpsc::error::SendError; use tokio::sync::{mpsc, Notify}; use tokio::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{info_span, instrument, Instrument, Span}; + pub struct BackendV3 { + /// Events streaming channel + events: (Sender, Receiver), + /// Request queue queue: Queue, /// Notify batcher on queue appends @@ -66,6 +73,7 @@ impl BackendV3 { )); Self { + events: channel(1), queue, batching_task_notifier, client, @@ -112,6 +120,10 @@ impl Backend for BackendV3 { .is_ok() } + fn events(&self) -> Receiver { + self.events.0.subscribe() + } + fn start_health(&self) -> bool { true } diff --git a/router/Cargo.toml b/router/Cargo.toml index 9326258d..bdaefcc7 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -31,11 +31,11 @@ serde_json = "1.0.107" thiserror = "1.0.48" tokenizers = { workspace = true } tokio = { version = "1.32.0", features = [ - "rt", - "rt-multi-thread", - "parking_lot", - "signal", - "sync", + "rt", + "rt-multi-thread", + "parking_lot", + "signal", + "sync", ] } tokio-stream = "0.1.14" tower-http = { version = "0.5.1", features = ["cors"] } @@ -46,7 +46,7 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } ngrok = { version = "0.13.1", features = ["axum"], optional = true } init-tracing-opentelemetry = { version = "0.14.1", features = [ - "opentelemetry-otlp", + "opentelemetry-otlp", ] } minijinja = { workspace = true, features = ["loop_controls"] } minijinja-contrib = { workspace = true } @@ -57,9 +57,9 @@ image = "0.25.1" base64 = { workspace = true } sysinfo = "0.30.13" uuid = { version = "1.9.1", default-features = false, features = [ - "v4", - "fast-rng", - "macro-diagnostics", + "v4", + "fast-rng", + "macro-diagnostics", ] } csv = "1.3.0" ureq = "=2.9" @@ -73,5 +73,6 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } [features] default = ["ngrok"] ngrok = ["dep:ngrok"] +engine-state = [] google = [] kserve = [] diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 7eb8a41b..bc958fee 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -19,12 +19,30 @@ use serde::Serialize; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use thiserror::Error; +use tokio::sync::broadcast::Receiver; use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError}; use tokio::time::Instant; use 
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 7eb8a41b..bc958fee 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -19,12 +19,30 @@ use serde::Serialize;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 use thiserror::Error;
+use tokio::sync::broadcast::Receiver;
 use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
 use tracing::instrument;
 
+/// Stores real-time information about batching engine usage (expressed in tokens)
+#[cfg(feature = "engine-state")]
+#[derive(Debug, Copy, Clone, Serialize)]
+pub struct EngineState {
+    /// Number of tokens currently participating in the current batch
+    in_flight: u32,
+
+    /// Maximum number of tokens which can participate in a batch
+    in_flight_max: u32,
+
+    /// Number of tokens currently waiting in the queue for future batching
+    in_queue: u32,
+
+    /// Maximum number of tokens which can wait in the queue for future batching
+    in_queue_max: u32,
+}
+
 #[async_trait]
 pub trait Backend {
     fn schedule(
@@ -34,6 +52,11 @@ pub trait Backend {
 
     async fn health(&self, current_health: bool) -> bool;
 
+    /// Gets a receiver on the channel broadcasting events about the current
+    /// internal batching engine state
+    #[cfg(feature = "engine-state")]
+    fn events(&self) -> Receiver<EngineState>;
+
     /// The state of the health on startup
     /// Typically false, or true if the backend includes
     /// a warmup phase.
@@ -95,6 +118,12 @@ impl Infer {
         }
     }
 
+    #[cfg(feature = "engine-state")]
+    #[inline]
+    pub(crate) fn events(&self) -> Receiver<EngineState> {
+        self.backend.events()
+    }
+
     /// Add a new request to the queue and return a stream of InferStreamResponse
     #[instrument(skip_all)]
     pub(crate) async fn generate_stream<'a>(
diff --git a/router/src/server.rs b/router/src/server.rs
index 29392771..b952c62a 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -62,6 +62,8 @@ use tokio::select;
 use tokio::signal;
 use tokio::sync::oneshot;
 use tokio::time::Instant;
+use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
+use tokio_stream::wrappers::BroadcastStream;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
 use utoipa::OpenApi;
@@ -1502,6 +1504,28 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     prom_handle.render()
 }
 
+#[utoipa::path(get, tag = "Text Generation Inference", path = "/state")]
+#[instrument(skip_all)]
+async fn state(
+    Extension(infer): Extension<Infer>,
+) -> Result<Sse<impl Stream<Item = Result<Event, axum::Error>>>, StatusCode> {
+    if cfg!(feature = "engine-state") {
+        let stream = infer.events();
+        let sse =
+            Sse::new(BroadcastStream::from(stream).map(|state| {
+                Event::default().json_data(state.map_err(|err| axum::Error::new(err))?)
+            }))
+            .keep_alive(
+                KeepAlive::new()
+                    .interval(Duration::from_secs(5))
+                    .text("more_open_models_on_hf"),
+            );
+        Ok(sse)
+    } else {
+        Err(StatusCode::NOT_IMPLEMENTED)
+    }
+}
+
 #[derive(Clone, Debug)]
 pub(crate) struct ComputeType(String);
 
@@ -1521,6 +1545,7 @@ metrics,
 openai_get_model_info,
 sagemaker_compatibility,
 get_chat_tokenize,
+state,
 ),
 components(
 schemas(
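
The `/state` handler above pushes every broadcast `EngineState` through `Event::json_data`, i.e. one JSON object per SSE `data:` frame. A consumer therefore only needs a struct mirroring the four counters. A client-side sketch, assuming the payload shape implied by the `Serialize` derive (struct name and values are made up; only the field names come from the patch):

    use serde::Deserialize;

    // Client-side mirror of the `EngineState` payload streamed by `/state`.
    #[derive(Debug, Deserialize)]
    struct EngineStateView {
        in_flight: u32,
        in_flight_max: u32,
        in_queue: u32,
        in_queue_max: u32,
    }

    fn main() -> serde_json::Result<()> {
        // Example `data:` payload of a single SSE frame (numbers made up).
        let frame = r#"{"in_flight":4096,"in_flight_max":16384,"in_queue":512,"in_queue_max":32768}"#;
        let state: EngineStateView = serde_json::from_str(frame)?;

        // One obvious use: derive a saturation ratio for autoscaling decisions.
        let batch_saturation = state.in_flight as f64 / state.in_flight_max as f64;
        println!("{state:?} -> batch {:.0}% full", batch_saturation * 100.0);
        Ok(())
    }
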
From 1a9c5dec76bc4a7bb875a2584efcdf5dc7719aa8 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Thu, 27 Feb 2025 21:33:41 +0100
Subject: [PATCH 3/5] feat(metrics): update Cargo.lock

---
 Cargo.lock | 1 +
 1 file changed, 1 insertion(+)

diff --git a/Cargo.lock b/Cargo.lock
index 4603f77d..1b73ff57 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5101,6 +5101,7 @@ dependencies = [
  "futures-core",
  "pin-project-lite",
  "tokio",
+ "tokio-util",
 ]
 
 [[package]]
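
The touched block is the `dependencies` list of the `tokio-stream` package: enabling its `sync` feature in patch 2 (required for `BroadcastStream`) pulls in the optional `tokio-util` dependency, which is the single insertion recorded here.
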
From 712199c769f85ab062c9cd84d151a3c8c4efc372 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Thu, 27 Feb 2025 22:43:20 +0100
Subject: [PATCH 4/5] feat(metrics): dispatch internal engine state events
 from queuing/batching tasks

---
 backends/v3/src/backend.rs | 48 ++++++++++++++++++++++----
 backends/v3/src/queue.rs   | 69 ++++++++++++++++++++++++++++++++------
 router/src/infer/mod.rs    | 21 +++++++++---
 router/src/server.rs       |  2 +-
 4 files changed, 117 insertions(+), 23 deletions(-)

diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs
index 6b3993bb..bfa87773 100644
--- a/backends/v3/src/backend.rs
+++ b/backends/v3/src/backend.rs
@@ -11,22 +11,27 @@ use text_generation_router::infer::{
 };
 use text_generation_router::validation::ValidGenerateRequest;
 use text_generation_router::{FinishReason, PrefillToken, Token};
-use tokio::sync::broadcast::{channel, Receiver, Sender};
+use tokio::sync::broadcast::{channel, Receiver as BroadcastReceiver, Sender as BroadcastSender};
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify};
+use tokio::sync::{mpsc, Notify, RwLock};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{info_span, instrument, Instrument, Span};
 
 
 pub struct BackendV3 {
-    /// Events streaming channel
-    events: (Sender<EngineState>, Receiver<EngineState>),
+    /// Internal batching state exposing info for the proxy
+    state: Arc<RwLock<EngineState>>,
 
     /// Request queue
     queue: Queue,
+
+    /// Events streaming channel
+    state_events: (BroadcastSender<EngineState>, BroadcastReceiver<EngineState>),
+
     /// Notify batcher on queue appends
     batching_task_notifier: Arc<Notify>,
+
+    /// Client clone, used for health checks to skip the queue
     client: ShardedClient,
 }
 
@@ -48,6 +53,12 @@ impl BackendV3 {
 
         let block_size = shard_info.block_size;
 
+        let state_events = channel(1);
+        let state = Arc::new(RwLock::new(EngineState::new(
+            max_batch_total_tokens,
+            2 * max_batch_total_tokens,
+        )));
+
         let queue = Queue::new(
             shard_info.requires_padding,
             block_size,
@@ -56,6 +67,8 @@ impl BackendV3 {
             shard_info.speculate,
             max_batch_total_tokens,
             shard_info.support_chunking,
+            Arc::clone(&state),
+            state_events.0.clone(),
         );
 
         let batching_task_notifier = Arc::new(Notify::new());
@@ -69,11 +82,14 @@ impl BackendV3 {
             max_batch_size,
             shard_info.support_chunking,
             queue.clone(),
+            state.clone(),
+            state_events.0.clone(),
             batching_task_notifier.clone(),
         ));
 
         Self {
-            events: channel(1),
+            state,
+            state_events,
             queue,
             batching_task_notifier,
             client,
@@ -112,8 +136,8 @@ impl Backend for BackendV3 {
             .is_ok()
     }
 
-    fn events(&self) -> Receiver<EngineState> {
-        self.events.0.subscribe()
+    fn events(&self) -> BroadcastReceiver<EngineState> {
+        self.state_events.0.subscribe()
     }
 
     fn start_health(&self) -> bool {
@@ -147,6 +163,8 @@ pub(crate) async fn batching_task(
     max_batch_size: Option<usize>,
     support_chunking: bool,
     queue: Queue,
+    engine_state: Arc<RwLock<EngineState>>,
+    batch_events: BroadcastSender<EngineState>,
     notifier: Arc<Notify>,
 ) {
     // Infinite loop
@@ -182,6 +200,22 @@ pub(crate) async fn batching_task(
             metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
             metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
 
+            // Dispatch new state to the proxy
+            {
+                // Critical section, doing as little as possible here
+                {
+                    let mut engine_state = engine_state.write().await;
+                    engine_state.in_flight = batch_max_tokens;
+                }
+
+                // Send the new state to the channel for broadcasting
+                if let Err(err) = batch_events.send(*engine_state.read().await) {
+                    tracing::warn!(
+                        "Failed to broadcast engine state ({batch_max_tokens} tokens in flight): {err}"
+                    )
+                }
+            }
+
             let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
 
             let (min_size, max_size, prefill_token_budget) = if support_chunking {
diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index c0ac3e87..71158cf0 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -6,14 +6,17 @@ use crate::client::{
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::max;
 use std::collections::VecDeque;
-use text_generation_router::infer::InferError;
+use std::sync::Arc;
 use text_generation_router::infer::InferStreamResponse;
+use text_generation_router::infer::{EngineState, InferError};
 use text_generation_router::validation::{
     Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
     ValidStoppingParameters,
 };
-use tokio::sync::{mpsc, oneshot};
+use tokio::sync::broadcast::Sender as BroadcastSender;
+use tokio::sync::{mpsc, oneshot, RwLock};
 use tokio::time::Instant;
+use tracing::log::warn;
 use tracing::{info_span, instrument, Instrument, Span};
 
 /// Queue entry
@@ -51,6 +54,8 @@ impl Queue {
         speculate: u32,
         max_batch_total_tokens: u32,
         support_chunking: bool,
+        engine_state: Arc<RwLock<EngineState>>,
+        queue_events: BroadcastSender<EngineState>,
     ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -64,7 +69,9 @@ impl Queue {
             speculate,
             max_batch_total_tokens,
             support_chunking,
+            engine_state,
             queue_receiver,
+            queue_events,
         ));
 
         Self { queue_sender }
@@ -113,7 +120,7 @@ impl Queue {
     }
 }
 
-// Background task responsible of the queue state
+// Background task responsible for the queue state
 #[allow(clippy::too_many_arguments)]
 async fn queue_task(
     requires_padding: bool,
@@ -123,7 +130,9 @@ async fn queue_task(
     speculate: u32,
     max_batch_total_tokens: u32,
     support_chunking: bool,
+    engine_state: Arc<RwLock<EngineState>>,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
+    queue_events: BroadcastSender<EngineState>,
 ) {
     let mut state = State::new(
         requires_padding,
@@ -138,9 +147,29 @@ async fn queue_task(
     while let Some(cmd) = receiver.recv().await {
         match cmd {
             QueueCommand::Append(entry, span) => {
-                metrics::gauge!("tgi_queue_size").increment(1.0);
-                metrics::gauge!("tgi_queue_size_tokens").increment(entry.request.input_length);
+                let entry_num_tokens = entry.request.input_length;
                 span.in_scope(|| state.append(*entry));
+                metrics::gauge!("tgi_queue_size").increment(1.0);
+                metrics::gauge!("tgi_queue_size_tokens").increment(entry_num_tokens);
+
+                // Dispatch new state to the proxy
+                {
+                    // Shared read access only: fetch the current queued-token count
+                    let num_queued_tokens = engine_state.read().await.in_queue;
+                    {
+                        // Critical section, doing as little as possible here
+                        let mut engine_state = engine_state.write().await;
+                        engine_state.in_queue = num_queued_tokens + entry_num_tokens;
+                    }
+
+                    // Send the new state to the channel for broadcasting
+                    if let Err(err) = queue_events.send(*engine_state.read().await) {
+                        tracing::warn!(
+                            "Failed to broadcast engine state ({} tokens in queue): {err}",
+                            num_queued_tokens + entry_num_tokens
+                        )
+                    }
+                }
             }
             QueueCommand::NextBatch {
                 min_size,
@@ -156,14 +185,32 @@ async fn queue_task(
                     .instrument(span)
                     .await;
                 response_sender.send(next_batch).unwrap();
 
-                metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
-                metrics::gauge!("tgi_queue_size_tokens").set(
-                    state
-                        .entries
-                        .iter()
-                        .map(|(_, e)| e.request.input_length as f64)
-                        .sum::<f64>(),
-                );
+                {
+                    let num_batch_tokens = state
+                        .entries
+                        .iter()
+                        .map(|(_, e)| e.request.input_length)
+                        .sum::<u32>();
+                    metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
+                    metrics::gauge!("tgi_queue_size_tokens").set(num_batch_tokens as f64);
+
+                    // Dispatch new state to the proxy
+                    {
+                        // Critical section, doing as little as possible here
+                        {
+                            let mut engine_state = engine_state.write().await;
+                            engine_state.in_queue = num_batch_tokens;
+                        }
+
+                        // Send the new state to the channel for broadcasting
+                        if let Err(err) = queue_events.send(*engine_state.read().await) {
+                            tracing::warn!(
+                                "Failed to broadcast engine state ({} tokens in queue): {err}",
+                                num_batch_tokens
+                            )
+                        }
+                    }
+                }
             }
         }
     }
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index bc958fee..0123f54f 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -31,16 +31,29 @@ use tracing::instrument;
 #[derive(Debug, Copy, Clone, Serialize)]
 pub struct EngineState {
     /// Number of tokens currently participating in the current batch
-    in_flight: u32,
+    pub in_flight: u32,
 
     /// Maximum number of tokens which can participate in a batch
-    in_flight_max: u32,
+    pub in_flight_max: u32,
 
     /// Number of tokens currently waiting in the queue for future batching
-    in_queue: u32,
+    pub in_queue: u32,
 
     /// Maximum number of tokens which can wait in the queue for future batching
-    in_queue_max: u32,
+    pub in_queue_max: u32,
+}
+
+#[cfg(feature = "engine-state")]
+impl EngineState {
+    #[inline]
+    pub fn new(in_flight_max: u32, in_queue_max: u32) -> Self {
+        EngineState {
+            in_flight: 0,
+            in_flight_max,
+            in_queue: 0,
+            in_queue_max,
+        }
+    }
 }
 
 #[async_trait]
diff --git a/router/src/server.rs b/router/src/server.rs
index b952c62a..c394f7b8 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -62,7 +62,6 @@ use tokio::select;
 use tokio::signal;
 use tokio::sync::oneshot;
 use tokio::time::Instant;
-use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
 use tokio_stream::wrappers::BroadcastStream;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
@@ -2398,6 +2397,7 @@ async fn start(
         .route("/health", get(health))
         .route("/ping", get(health))
         .route("/metrics", get(metrics))
+        .route("/state", get(state))
         .route("/v1/models", get(openai_get_model_info));
 
     let compute_type =
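
The state channel in this patch is created with capacity 1 (`channel(1)` in `BackendV3::new`), which makes it a latest-value feed: a subscriber that polls slower than the engine publishes first observes a `Lagged` error and then resumes at the freshest state, rather than draining a backlog. This is also why `/state` wraps the receiver in `BroadcastStream`, whose items are `Result`s. A small self-contained sketch of that tokio behavior (the `u32` payload stands in for `EngineState`):

    use tokio::sync::broadcast;

    #[tokio::main]
    async fn main() {
        // Capacity-1 channel, as used for the engine state events.
        let (tx, mut rx) = broadcast::channel::<u32>(1);

        // Three state updates arrive before the subscriber polls;
        // the single-slot buffer only retains the newest one.
        for tokens_in_queue in [100, 200, 300] {
            tx.send(tokens_in_queue).unwrap();
        }

        // The subscriber first learns it missed two updates...
        assert!(matches!(
            rx.recv().await,
            Err(broadcast::error::RecvError::Lagged(2))
        ));
        // ...then receives the most recent state.
        assert_eq!(rx.recv().await.unwrap(), 300);
    }

For a monitoring feed this is the right trade-off: intermediate states are worthless once a newer one exists, and publishers never block on slow consumers.
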
From f72547c9fbe2a3d789ff23648a6c25430fad1209 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Thu, 27 Feb 2025 22:56:04 +0100
Subject: [PATCH 5/5] feat(metrics): remove mandatory ngrok feature from the
 backendv3 crate

---
 backends/v3/Cargo.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml
index 1489efb4..da8e8353 100644
--- a/backends/v3/Cargo.toml
+++ b/backends/v3/Cargo.toml
@@ -18,7 +18,7 @@ async-trait = "0.1.74"
 async-stream = "0.3.5"
 axum = { version = "0.7", features = ["json"] }
 axum-tracing-opentelemetry = "0.16"
-text-generation-router = { path = "../../router", features = ["engine-state", "ngrok"] }
+text-generation-router = { path = "../../router", features = ["engine-state"] }
 clap = { version = "4.4.5", features = ["derive", "env"] }
 grpc-metadata = { path = "../grpc-metadata" }
 futures = "0.3.28"
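
One Cargo subtlety behind this last patch: since the dependency does not set `default-features = false`, the router's default feature set still applies, and `default = ["ngrok"]` in router/Cargo.toml keeps ngrok compiled in; the change only stops requesting it explicitly. Separately, the series relies on two conditional-compilation mechanisms that behave differently; a minimal illustration (function names are made up, not from the codebase):

    // `#[cfg]` removes the item entirely when the feature is off; callers in
    // other configurations must themselves be gated (the pattern used for
    // `EngineState` and `Backend::events`).
    #[cfg(feature = "engine-state")]
    pub fn subscribe_to_engine_state() {
        /* only compiled with --features engine-state */
    }

    // `cfg!` is an ordinary boolean, so both branches are always type-checked
    // and everything referenced inside must exist in every configuration (the
    // pattern used by the `/state` handler to return 501 when disabled).
    pub fn state_endpoint_status() -> &'static str {
        if cfg!(feature = "engine-state") {
            "streaming engine state"
        } else {
            "not implemented"
        }
    }

Because the `/state` handler takes the `cfg!` route, crates building the router with that handler need the `engine-state` feature enabled for the gated `events()` call to type-check, which is exactly what the backends/v3 manifest above does.
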