diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs
index 6b3993bb..bfa87773 100644
--- a/backends/v3/src/backend.rs
+++ b/backends/v3/src/backend.rs
@@ -11,22 +11,27 @@ use text_generation_router::infer::{
 };
 use text_generation_router::validation::ValidGenerateRequest;
 use text_generation_router::{FinishReason, PrefillToken, Token};
-use tokio::sync::broadcast::{channel, Receiver, Sender};
+use tokio::sync::broadcast::{channel, Receiver as BroadcastReceiver, Sender as BroadcastSender};
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify};
+use tokio::sync::{mpsc, Notify, RwLock};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{info_span, instrument, Instrument, Span};
 
 pub struct BackendV3 {
-    /// Events streaming channel
-    events: (Sender<EngineState>, Receiver<EngineState>),
+    /// Internal batching state exposing info for the proxy
+    state: Arc<RwLock<EngineState>>,
     /// Request queue
     queue: Queue,
+
+    /// Events streaming channel
+    state_events: (BroadcastSender<EngineState>, BroadcastReceiver<EngineState>),
+
     /// Notify batcher on queue appends
     batching_task_notifier: Arc<Notify>,
+
     /// Client clone, used for health checks to skip the queue
     client: ShardedClient,
 }
@@ -48,6 +53,12 @@ impl BackendV3 {
 
         let block_size = shard_info.block_size;
 
+        let state_events = channel(1);
+        let state = Arc::new(RwLock::new(EngineState::new(
+            max_batch_total_tokens,
+            2 * max_batch_total_tokens,
+        )));
+
         let queue = Queue::new(
             shard_info.requires_padding,
             block_size,
@@ -56,6 +67,8 @@ impl BackendV3 {
             shard_info.speculate,
             max_batch_total_tokens,
             shard_info.support_chunking,
+            Arc::clone(&state),
+            state_events.0.clone(),
         );
 
         let batching_task_notifier = Arc::new(Notify::new());
@@ -69,11 +82,14 @@ impl BackendV3 {
             max_batch_size,
             shard_info.support_chunking,
             queue.clone(),
+            state.clone(),
+            state_events.0.clone(),
             batching_task_notifier.clone(),
         ));
 
         Self {
-            events: channel(1),
+            state,
+            state_events,
             queue,
             batching_task_notifier,
             client,
         }
@@ -120,8 +136,8 @@ impl Backend for BackendV3 {
             .is_ok()
     }
 
-    fn events(&self) -> Receiver<EngineState> {
-        self.events.0.subscribe()
+    fn events(&self) -> BroadcastReceiver<EngineState> {
+        self.state_events.0.subscribe()
     }
 
     fn start_health(&self) -> bool {
@@ -147,6 +163,8 @@ pub(crate) async fn batching_task(
     max_batch_size: Option<usize>,
     support_chunking: bool,
     queue: Queue,
+    engine_state: Arc<RwLock<EngineState>>,
+    batch_events: BroadcastSender<EngineState>,
     notifier: Arc<Notify>,
 ) {
     // Infinite loop
@@ -182,6 +200,22 @@ pub(crate) async fn batching_task(
 
                 metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
                 metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
 
+                // Dispatch new state to the proxy
+                {
+                    // Critical section, doing as little as possible here
+                    {
+                        let mut engine_state = engine_state.write().await;
+                        engine_state.in_flight = batch_max_tokens;
+                    }
+
+                    // Send new state to the channel for broadcasting
+                    if let Err(err) = batch_events.send(*engine_state.read().await) {
+                        tracing::warn!(
+                            "Failed to send BatchEvent::BatchChanged({batch_max_tokens}): {err}"
+                        )
+                    }
+                }
+
                 let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
 
                 let (min_size, max_size, prefill_token_budget) = if support_chunking {
diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index c0ac3e87..71158cf0 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -6,14 +6,17 @@ use crate::client::{
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::max;
 use std::collections::VecDeque;
-use text_generation_router::infer::InferError;
+use std::sync::Arc;
 use text_generation_router::infer::InferStreamResponse;
+use text_generation_router::infer::{EngineState, InferError};
 use text_generation_router::validation::{
     Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
     ValidStoppingParameters,
 };
-use tokio::sync::{mpsc, oneshot};
+use tokio::sync::broadcast::Sender as BroadcastSender;
+use tokio::sync::{mpsc, oneshot, RwLock};
 use tokio::time::Instant;
+use tracing::log::warn;
 use tracing::{info_span, instrument, Instrument, Span};
 
 /// Queue entry
@@ -51,6 +54,8 @@ impl Queue {
         speculate: u32,
         max_batch_total_tokens: u32,
         support_chunking: bool,
+        engine_state: Arc<RwLock<EngineState>>,
+        queue_events: BroadcastSender<EngineState>,
     ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -64,7 +69,9 @@ impl Queue {
             speculate,
             max_batch_total_tokens,
             support_chunking,
+            engine_state,
             queue_receiver,
+            queue_events,
         ));
 
         Self { queue_sender }
     }
@@ -113,7 +120,7 @@ impl Queue {
     }
 }
 
-// Background task responsible of the queue state
+// Background task responsible for the queue state
 #[allow(clippy::too_many_arguments)]
 async fn queue_task(
     requires_padding: bool,
@@ -123,7 +130,9 @@ async fn queue_task(
     speculate: u32,
     max_batch_total_tokens: u32,
     support_chunking: bool,
+    engine_state: Arc<RwLock<EngineState>>,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
+    queue_events: BroadcastSender<EngineState>,
 ) {
     let mut state = State::new(
         requires_padding,
@@ -138,9 +147,29 @@ async fn queue_task(
     while let Some(cmd) = receiver.recv().await {
         match cmd {
             QueueCommand::Append(entry, span) => {
-                metrics::gauge!("tgi_queue_size").increment(1.0);
-                metrics::gauge!("tgi_queue_size_tokens").increment(entry.request.input_length);
+                let entry_num_tokens = entry.request.input_length;
                 span.in_scope(|| state.append(*entry));
+                metrics::gauge!("tgi_queue_size").increment(1.0);
+                metrics::gauge!("tgi_queue_size_tokens").increment(entry_num_tokens);
+
+                // Dispatch new state to the proxy
+                {
+                    // Lock free operation (read)
+                    let num_queued_tokens = engine_state.read().await.in_queue;
+                    {
+                        // Critical section, doing as little as possible here
+                        let mut engine_state = engine_state.write().await;
+                        engine_state.in_queue = num_queued_tokens + entry_num_tokens;
+                    }
+
+                    // Send new state to the channel for broadcasting
+                    if let Err(err) = queue_events.send(*engine_state.read().await) {
+                        tracing::warn!(
+                            "Failed to send BatchEvent::QueueChanged({}): {err}",
+                            num_queued_tokens + entry_num_tokens
+                        )
+                    }
+                }
             }
             QueueCommand::NextBatch {
                 min_size,
                 max_size,
@@ -156,14 +185,32 @@ async fn queue_task(
                     .await;
                 response_sender.send(next_batch).unwrap();
 
-                metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
-                metrics::gauge!("tgi_queue_size_tokens").set(
-                    state
+                {
+                    let num_batch_tokens = state
                         .entries
                         .iter()
-                        .map(|(_, e)| e.request.input_length as f64)
-                        .sum::<f64>(),
-                );
+                        .map(|(_, e)| e.request.input_length)
+                        .sum::<u32>();
+                    metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
+                    metrics::gauge!("tgi_queue_size_tokens").set(num_batch_tokens as f64);
+
+                    // Dispatch new state to the proxy
+                    {
+                        // Critical section, doing as little as possible here
+                        {
+                            let mut engine_state = engine_state.write().await;
+                            engine_state.in_queue = num_batch_tokens;
+                        }
+
+                        // Send new state to the channel for broadcasting
+                        if let Err(err) = queue_events.send(*engine_state.read().await) {
+                            tracing::warn!(
+                                "Failed to send BatchEvent::QueueChanged({}): {err}",
+                                num_batch_tokens
+                            )
+                        }
+                    }
+                }
             }
         }
     }
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index bc958fee..0123f54f 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -31,16 +31,29 @@ use tracing::instrument;
 
 #[derive(Debug, Copy, Clone, Serialize)]
 pub struct EngineState {
     /// Number of tokens currently participating in current batch
-    in_flight: u32,
+    pub in_flight: u32,
     /// Maximum number of tokens which can participate in a batch
-    in_flight_max: u32,
+    pub in_flight_max: u32,
     /// Number of tokens currently waiting in the queue for future batching
-    in_queue: u32,
+    pub in_queue: u32,
     /// Maximum number of tokens which can wait in the queue for future batching
-    in_queue_max: u32,
+    pub in_queue_max: u32,
+}
+
+#[cfg(feature = "engine-state")]
+impl EngineState {
+    #[inline]
+    pub fn new(in_flight_max: u32, in_queue_max: u32) -> Self {
+        EngineState {
+            in_flight: 0,
+            in_flight_max,
+            in_queue: 0,
+            in_queue_max,
+        }
+    }
 }
 
 #[async_trait]
diff --git a/router/src/server.rs b/router/src/server.rs
index b952c62a..c394f7b8 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -62,7 +62,6 @@ use tokio::select;
 use tokio::signal;
 use tokio::sync::oneshot;
 use tokio::time::Instant;
-use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
 use tokio_stream::wrappers::BroadcastStream;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
@@ -2398,6 +2397,7 @@ async fn start(
         .route("/health", get(health))
         .route("/ping", get(health))
         .route("/metrics", get(metrics))
+        .route("/state", get(state))
         .route("/v1/models", get(openai_get_model_info));
 
     let compute_type =
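The `state` handler wired up by the new `/state` route lies outside the hunks above. A minimal sketch of how such a handler could look, assuming `Infer` forwards the backend's `events()` receiver shown in backend.rs and that the axum extension wiring and import paths match the router's layout; the handler body and names below are illustrative, not the code shipped by this patch:

```rust
use axum::response::sse::{Event, Sse};
use axum::Extension;
use tokio_stream::wrappers::BroadcastStream;
use tokio_stream::{Stream, StreamExt};

use crate::infer::Infer;

async fn state(
    Extension(infer): Extension<Infer>,
) -> Sse<impl Stream<Item = Result<Event, axum::Error>>> {
    // Subscribe to the engine-state broadcast channel; every queue append and
    // batch update publishes a fresh EngineState snapshot (see queue.rs and
    // backend.rs above). `infer.events()` forwarding the backend receiver is
    // an assumption of this sketch.
    let receiver = infer.events();

    let stream = BroadcastStream::new(receiver).filter_map(|update| match update {
        // Serialize each Copy + Serialize EngineState snapshot into an SSE event
        Ok(engine_state) => Event::default().json_data(engine_state).ok().map(Ok),
        // A lagged receiver only means intermediate snapshots were skipped
        Err(_) => None,
    });

    Sse::new(stream)
}
```

Because the broadcast channel is created with `channel(1)` and `EngineState` is `Copy`, a slow subscriber can only miss intermediate snapshots; it never applies back-pressure to the batcher or the queue task.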
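On the consumer side (the "proxy" the comments refer to), every message is a complete `EngineState` snapshot, so a subscriber needs no extra bookkeeping to derive load figures. A hypothetical watcher built only on the public fields and the `EngineState::new` ceilings constructed in backend.rs above; the function and its logging are illustrative:

```rust
use text_generation_router::infer::EngineState;
use tokio::sync::broadcast::error::RecvError;
use tokio::sync::broadcast::Receiver;

/// Illustrative only: log how saturated the batch and queue token budgets are.
async fn watch_engine_state(mut events: Receiver<EngineState>) {
    loop {
        match events.recv().await {
            Ok(state) => {
                // Fraction of the batch token budget currently in flight
                let batch_load = state.in_flight as f64 / state.in_flight_max.max(1) as f64;
                // Fraction of the queue token budget currently waiting
                // (in_queue_max is 2 * max_batch_total_tokens as constructed above)
                let queue_load = state.in_queue as f64 / state.in_queue_max.max(1) as f64;
                tracing::info!("batch load: {batch_load:.2}, queue load: {queue_load:.2}");
            }
            // Snapshots are self-contained, so dropping some under lag is harmless
            Err(RecvError::Lagged(_)) => continue,
            // All senders are gone: the backend is shutting down
            Err(RecvError::Closed) => break,
        }
    }
}
```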