mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 00:12:08 +00:00
feat(metrics): exposes the engine state as an endpoint
This commit is contained in:
parent
bb8f59632f
commit
8de41f63a8
@ -18,7 +18,7 @@ async-trait = "0.1.74"
|
||||
async-stream = "0.3.5"
|
||||
axum = { version = "0.7", features = ["json"] }
|
||||
axum-tracing-opentelemetry = "0.16"
|
||||
text-generation-router = { path = "../../router" }
|
||||
text-generation-router = { path = "../../router", features = ["engine-state", "ngrok"] }
|
||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||
grpc-metadata = { path = "../grpc-metadata" }
|
||||
futures = "0.3.28"
|
||||
@ -37,13 +37,13 @@ slotmap = "1.0.7"
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { workspace = true }
|
||||
tokio = { version = "1.32.0", features = [
|
||||
"rt",
|
||||
"rt-multi-thread",
|
||||
"parking_lot",
|
||||
"signal",
|
||||
"sync",
|
||||
"rt",
|
||||
"rt-multi-thread",
|
||||
"parking_lot",
|
||||
"signal",
|
||||
"sync",
|
||||
] }
|
||||
tokio-stream = "0.1.14"
|
||||
tokio-stream = { version = "0.1.14", features = ["sync"] }
|
||||
tower-http = { version = "0.5.1", features = ["cors"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-opentelemetry = "0.21.0"
|
||||
@ -51,7 +51,7 @@ tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||
utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
||||
"opentelemetry-otlp",
|
||||
"opentelemetry-otlp",
|
||||
] }
|
||||
minijinja = { workspace = true }
|
||||
minijinja-contrib = { workspace = true }
|
||||
|
@ -6,16 +6,23 @@ use crate::queue::{Entry, Queue};
|
||||
use async_trait::async_trait;
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::Arc;
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
use text_generation_router::infer::{
|
||||
Backend, EngineState, GeneratedText, InferError, InferStreamResponse,
|
||||
};
|
||||
use text_generation_router::validation::ValidGenerateRequest;
|
||||
use text_generation_router::{FinishReason, PrefillToken, Token};
|
||||
use tokio::sync::broadcast::{channel, Receiver, Sender};
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
||||
|
||||
pub struct BackendV3 {
|
||||
/// Events streaming channel
|
||||
events: (Sender<EngineState>, Receiver<EngineState>),
|
||||
|
||||
/// Request queue
|
||||
queue: Queue,
|
||||
/// Notify batcher on queue appends
|
||||
@ -66,6 +73,7 @@ impl BackendV3 {
|
||||
));
|
||||
|
||||
Self {
|
||||
events: channel(1),
|
||||
queue,
|
||||
batching_task_notifier,
|
||||
client,
|
||||
@ -112,6 +120,10 @@ impl Backend for BackendV3 {
|
||||
.is_ok()
|
||||
}
|
||||
|
||||
/// Hands out a receiver subscribed to the engine-state broadcast channel.
fn events(&self) -> Receiver<EngineState> {
    // Subscriptions are created from the sending half; the receiver stored next to it is
    // not touched here.
    let (sender, _) = &self.events;
    sender.subscribe()
}
|
||||
|
||||
/// Health state reported at startup: this backend always answers `true`.
// NOTE(review): presumably because this backend goes through a warmup phase — confirm.
fn start_health(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
@ -31,11 +31,11 @@ serde_json = "1.0.107"
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { workspace = true }
|
||||
tokio = { version = "1.32.0", features = [
|
||||
"rt",
|
||||
"rt-multi-thread",
|
||||
"parking_lot",
|
||||
"signal",
|
||||
"sync",
|
||||
"rt",
|
||||
"rt-multi-thread",
|
||||
"parking_lot",
|
||||
"signal",
|
||||
"sync",
|
||||
] }
|
||||
# `sync` is required for `tokio_stream::wrappers::BroadcastStream`, used by the `/state` endpoint.
tokio-stream = { version = "0.1.14", features = ["sync"] }
|
||||
tower-http = { version = "0.5.1", features = ["cors"] }
|
||||
@ -46,7 +46,7 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
||||
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
||||
"opentelemetry-otlp",
|
||||
"opentelemetry-otlp",
|
||||
] }
|
||||
minijinja = { workspace = true, features = ["loop_controls"] }
|
||||
minijinja-contrib = { workspace = true }
|
||||
@ -57,9 +57,9 @@ image = "0.25.1"
|
||||
base64 = { workspace = true }
|
||||
sysinfo = "0.30.13"
|
||||
uuid = { version = "1.9.1", default-features = false, features = [
|
||||
"v4",
|
||||
"fast-rng",
|
||||
"macro-diagnostics",
|
||||
"v4",
|
||||
"fast-rng",
|
||||
"macro-diagnostics",
|
||||
] }
|
||||
csv = "1.3.0"
|
||||
ureq = "=2.9"
|
||||
@ -73,5 +73,6 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
|
||||
[features]
|
||||
default = ["ngrok"]
|
||||
ngrok = ["dep:ngrok"]
|
||||
engine-state = []
|
||||
google = []
|
||||
kserve = []
|
||||
|
@ -19,12 +19,30 @@ use serde::Serialize;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use tokio::sync::broadcast::Receiver;
|
||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::instrument;
|
||||
|
||||
/// Real-time snapshot of the batching engine usage; every counter is expressed in tokens.
///
/// Only compiled when the `engine-state` feature is enabled.
// NOTE(review): every field is private and no constructor is visible in this view — confirm
// how a backend is expected to build a value before publishing it.
#[cfg(feature = "engine-state")]
#[derive(Clone, Copy, Debug, Serialize)]
pub struct EngineState {
    /// Tokens taking part in the batch that is currently running.
    in_flight: u32,

    /// Upper bound on the number of tokens a single batch can hold.
    in_flight_max: u32,

    /// Tokens sitting in the queue, waiting to be picked up by a future batch.
    in_queue: u32,

    /// Upper bound on the number of tokens allowed to wait in the queue.
    in_queue_max: u32,
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Backend {
|
||||
fn schedule(
|
||||
@ -34,6 +52,11 @@ pub trait Backend {
|
||||
|
||||
async fn health(&self, current_health: bool) -> bool;
|
||||
|
||||
/// Returns a receiver for the channel publishing events about the current internal
|
||||
/// batching engine state
|
||||
#[cfg(feature = "engine-state")]
|
||||
fn events(&self) -> Receiver<EngineState>;
|
||||
|
||||
/// The state of the health on startup
|
||||
/// Typically false, or true if the backend includes
|
||||
/// a warmup phase.
|
||||
@ -95,6 +118,12 @@ impl Infer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a receiver for the engine-state events exposed by the backend.
///
/// Only compiled when the `engine-state` feature is enabled; it simply forwards to the
/// backend's own `events()`.
#[cfg(feature = "engine-state")]
|
||||
#[inline]
|
||||
pub(crate) fn events(&self) -> Receiver<EngineState> {
|
||||
self.backend.events()
|
||||
}
|
||||
|
||||
/// Add a new request to the queue and return a stream of InferStreamResponse
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) async fn generate_stream<'a>(
|
||||
|
@ -62,6 +62,8 @@ use tokio::select;
|
||||
use tokio::signal;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
|
||||
use tokio_stream::wrappers::BroadcastStream;
|
||||
use tower_http::cors::{AllowOrigin, CorsLayer};
|
||||
use tracing::{info_span, instrument, Instrument};
|
||||
use utoipa::OpenApi;
|
||||
@ -1502,6 +1504,28 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
|
||||
prom_handle.render()
|
||||
}
|
||||
|
||||
#[utoipa::path(get, tag = "Text Generation Inference", path = "/state")]
|
||||
#[instrument(skip_all)]
|
||||
async fn state(
|
||||
Extension(infer): Extension<Infer>,
|
||||
) -> Result<Sse<impl Stream<Item = Result<Event, axum::Error>>>, StatusCode> {
|
||||
if cfg!(feature = "engine-state") {
|
||||
let stream = infer.events();
|
||||
let sse =
|
||||
Sse::new(BroadcastStream::from(stream).map(|state| {
|
||||
Event::default().json_data(state.map_err(|err| axum::Error::new(err))?)
|
||||
}))
|
||||
.keep_alive(
|
||||
KeepAlive::new()
|
||||
.interval(Duration::from_secs(5))
|
||||
.text("more_open_models_on_hf"),
|
||||
);
|
||||
Ok(sse)
|
||||
} else {
|
||||
Err(StatusCode::NOT_IMPLEMENTED)
|
||||
}
|
||||
}
|
||||
|
||||
/// Newtype wrapping a `String`.
// NOTE(review): how this value is produced and consumed is not visible in this chunk.
#[derive(Clone, Debug)]
|
||||
pub(crate) struct ComputeType(String);
|
||||
|
||||
@ -1521,6 +1545,7 @@ metrics,
|
||||
openai_get_model_info,
|
||||
sagemaker_compatibility,
|
||||
get_chat_tokenize,
|
||||
state,
|
||||
),
|
||||
components(
|
||||
schemas(
|
||||
|
Loading…
Reference in New Issue
Block a user