From 8de41f63a8bf53b4b99e855753f0b69d82d2490a Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Thu, 27 Feb 2025 16:58:02 +0100
Subject: [PATCH] feat(metrics): exposes the engine state as an endpoint

---
 backends/v3/Cargo.toml     | 16 ++++++++--------
 backends/v3/src/backend.rs | 14 +++++++++++++-
 router/Cargo.toml          | 19 ++++++++++---------
 router/src/infer/mod.rs    | 29 +++++++++++++++++++++++++++++
 router/src/server.rs       | 25 +++++++++++++++++++++++++
 5 files changed, 85 insertions(+), 18 deletions(-)

diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml
index 996290ed..1489efb4 100644
--- a/backends/v3/Cargo.toml
+++ b/backends/v3/Cargo.toml
@@ -18,7 +18,7 @@ async-trait = "0.1.74"
 async-stream = "0.3.5"
 axum = { version = "0.7", features = ["json"] }
 axum-tracing-opentelemetry = "0.16"
-text-generation-router = { path = "../../router" }
+text-generation-router = { path = "../../router", features = ["engine-state", "ngrok"] }
 clap = { version = "4.4.5", features = ["derive", "env"] }
 grpc-metadata = { path = "../grpc-metadata" }
 futures = "0.3.28"
@@ -37,13 +37,13 @@ slotmap = "1.0.7"
 thiserror = "1.0.48"
 tokenizers = { workspace = true }
 tokio = { version = "1.32.0", features = [
-    "rt",
-    "rt-multi-thread",
-    "parking_lot",
-    "signal",
-    "sync",
+    "rt",
+    "rt-multi-thread",
+    "parking_lot",
+    "signal",
+    "sync",
 ] }
-tokio-stream = "0.1.14"
+tokio-stream = { version = "0.1.14", features = ["sync"] }
 tower-http = { version = "0.5.1", features = ["cors"] }
 tracing = "0.1.37"
 tracing-opentelemetry = "0.21.0"
@@ -51,7 +51,7 @@ tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
 utoipa = { version = "4.2.0", features = ["axum_extras"] }
 utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
 init-tracing-opentelemetry = { version = "0.14.1", features = [
-    "opentelemetry-otlp",
+    "opentelemetry-otlp",
 ] }
 minijinja = { workspace = true }
 minijinja-contrib = { workspace = true }
diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs
index 98e8d76f..6b3993bb 100644
--- a/backends/v3/src/backend.rs
+++ b/backends/v3/src/backend.rs
@@ -6,16 +6,23 @@ use crate::queue::{Entry, Queue};
 use async_trait::async_trait;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
-use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::infer::{
+    Backend, EngineState, GeneratedText, InferError, InferStreamResponse,
+};
 use text_generation_router::validation::ValidGenerateRequest;
 use text_generation_router::{FinishReason, PrefillToken, Token};
+use tokio::sync::broadcast::{channel, Receiver, Sender};
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{info_span, instrument, Instrument, Span};
 
+
 pub struct BackendV3 {
+    /// Events streaming channel
+    events: (Sender<EngineState>, Receiver<EngineState>),
+
     /// Request queue
     queue: Queue,
     /// Notify batcher on queue appends
@@ -66,6 +73,7 @@ impl BackendV3 {
         ));
 
         Self {
+            events: channel(1),
             queue,
             batching_task_notifier,
             client,
@@ -112,6 +120,10 @@ impl Backend for BackendV3 {
             .is_ok()
     }
 
+    fn events(&self) -> Receiver<EngineState> {
+        self.events.0.subscribe()
+    }
+
     fn start_health(&self) -> bool {
         true
     }
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 9326258d..bdaefcc7 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -31,11 +31,11 @@ serde_json = "1.0.107"
 thiserror = "1.0.48"
 tokenizers = { workspace = true }
 tokio = { version = "1.32.0", features = [
"rt", - "rt-multi-thread", - "parking_lot", - "signal", - "sync", + "rt", + "rt-multi-thread", + "parking_lot", + "signal", + "sync", ] } tokio-stream = "0.1.14" tower-http = { version = "0.5.1", features = ["cors"] } @@ -46,7 +46,7 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } ngrok = { version = "0.13.1", features = ["axum"], optional = true } init-tracing-opentelemetry = { version = "0.14.1", features = [ - "opentelemetry-otlp", + "opentelemetry-otlp", ] } minijinja = { workspace = true, features = ["loop_controls"] } minijinja-contrib = { workspace = true } @@ -57,9 +57,9 @@ image = "0.25.1" base64 = { workspace = true } sysinfo = "0.30.13" uuid = { version = "1.9.1", default-features = false, features = [ - "v4", - "fast-rng", - "macro-diagnostics", + "v4", + "fast-rng", + "macro-diagnostics", ] } csv = "1.3.0" ureq = "=2.9" @@ -73,5 +73,6 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } [features] default = ["ngrok"] ngrok = ["dep:ngrok"] +engine-state = [] google = [] kserve = [] diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 7eb8a41b..bc958fee 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -19,12 +19,30 @@ use serde::Serialize; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use thiserror::Error; +use tokio::sync::broadcast::Receiver; use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError}; use tokio::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::StreamExt; use tracing::instrument; +/// Store real-time information about batch engine usage (expressed in tokens) +#[cfg(feature = "engine-state")] +#[derive(Debug, Copy, Clone, Serialize)] +pub struct EngineState { + /// Number of tokens currently participating in current batch + in_flight: u32, + + /// Maximum number of tokens which can participate in a batch + in_flight_max: u32, + + /// Number of tokens currently waiting in the queue for future batching + in_queue: u32, + + /// Maximum number of tokens which can wait in the queue for future batching + in_queue_max: u32, +} + #[async_trait] pub trait Backend { fn schedule( @@ -34,6 +52,11 @@ pub trait Backend { async fn health(&self, current_health: bool) -> bool; + /// Gets a reference to receiving-side channel generating events about the current internal + /// batching engine state + #[cfg(feature = "engine-state")] + fn events(&self) -> Receiver; + /// The state of the health on startup /// Typically false, or true if the backend includes /// a warmup phase. 
@@ -95,6 +118,12 @@ impl Infer {
         }
     }
 
+    #[cfg(feature = "engine-state")]
+    #[inline]
+    pub(crate) fn events(&self) -> Receiver<EngineState> {
+        self.backend.events()
+    }
+
     /// Add a new request to the queue and return a stream of InferStreamResponse
     #[instrument(skip_all)]
     pub(crate) async fn generate_stream<'a>(
diff --git a/router/src/server.rs b/router/src/server.rs
index 29392771..b952c62a 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -62,6 +62,8 @@ use tokio::select;
 use tokio::signal;
 use tokio::sync::oneshot;
 use tokio::time::Instant;
+use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
+use tokio_stream::wrappers::BroadcastStream;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
 use utoipa::OpenApi;
@@ -1502,6 +1504,28 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     prom_handle.render()
 }
 
+#[utoipa::path(get, tag = "Text Generation Inference", path = "/state")]
+#[instrument(skip_all)]
+async fn state(
+    Extension(infer): Extension<Infer>,
+) -> Result<Sse<impl Stream<Item = Result<Event, axum::Error>>>, StatusCode> {
+    if cfg!(feature = "engine-state") {
+        let stream = infer.events();
+        let sse =
+            Sse::new(BroadcastStream::from(stream).map(|state| {
+                Event::default().json_data(state.map_err(|err| axum::Error::new(err))?)
+            }))
+            .keep_alive(
+                KeepAlive::new()
+                    .interval(Duration::from_secs(5))
+                    .text("more_open_models_on_hf"),
+            );
+        Ok(sse)
+    } else {
+        Err(StatusCode::NOT_IMPLEMENTED)
+    }
+}
+
 #[derive(Clone, Debug)]
 pub(crate) struct ComputeType(String);
 
@@ -1521,6 +1545,7 @@ metrics,
 openai_get_model_info,
 sagemaker_compatibility,
 get_chat_tokenize,
+state,
 ),
 components(
 schemas(
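
Note (editorial, not part of the patch): the diff adds the broadcast channel to BackendV3, the `Backend::events()` subscription hook, and the `/state` SSE endpoint, but it does not show where the batching loop publishes `EngineState` snapshots. The sketch below illustrates the intended flow under that assumption; the local `EngineState` copy, the standalone `main`, and the concrete numbers are made up for the example, and only the tokio broadcast API is taken as given.

    use tokio::sync::broadcast;

    /// Local stand-in for the EngineState struct added in router/src/infer/mod.rs.
    #[derive(Debug, Copy, Clone)]
    struct EngineState {
        in_flight: u32,
        in_flight_max: u32,
        in_queue: u32,
        in_queue_max: u32,
    }

    #[tokio::main] // requires tokio with the "rt-multi-thread", "macros" and "sync" features
    async fn main() {
        // Capacity 1, matching `channel(1)` in the patch: if a subscriber falls
        // behind, only the most recent snapshot is kept.
        let (tx, _rx) = broadcast::channel::<EngineState>(1);

        // Stands in for the /state handler: it subscribes to the sender and would
        // forward every received snapshot as one SSE event.
        let mut events = tx.subscribe();
        let subscriber = tokio::spawn(async move {
            while let Ok(state) = events.recv().await {
                println!(
                    "in_flight={}/{} in_queue={}/{}",
                    state.in_flight, state.in_flight_max, state.in_queue, state.in_queue_max
                );
            }
        });

        // Stands in for whatever the batching loop does after a scheduling step;
        // the numbers are illustrative only.
        let _ = tx.send(EngineState {
            in_flight: 128,
            in_flight_max: 4096,
            in_queue: 12,
            in_queue_max: 1024,
        });

        // Dropping the last sender closes the channel and ends the subscriber loop.
        drop(tx);
        let _ = subscriber.await;
    }

Assuming the router keeps its default `--port 3000`, the resulting stream could then be followed with `curl -N http://localhost:3000/state`, which should print one SSE `data:` line per published snapshot plus the periodic keep-alive comment.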