commit 6569c64217
Funtowicz Morgan, 2025-04-17 17:22:04 +02:00, committed by GitHub
7 changed files with 200 additions and 23 deletions

Cargo.lock (generated)

@@ -5134,6 +5134,7 @@ dependencies = [
  "futures-core",
  "pin-project-lite",
  "tokio",
+ "tokio-util",
 ]

 [[package]]


@@ -18,7 +18,7 @@ async-trait = "0.1.74"
 async-stream = "0.3.5"
 axum = { version = "0.7", features = ["json"] }
 axum-tracing-opentelemetry = "0.16"
-text-generation-router = { path = "../../router" }
+text-generation-router = { path = "../../router", features = ["engine-state"] }
 clap = { version = "4.4.5", features = ["derive", "env"] }
 grpc-metadata = { path = "../grpc-metadata" }
 futures = "0.3.28"
@@ -43,7 +43,7 @@ tokio = { version = "1.32.0", features = [
     "signal",
     "sync",
 ] }
-tokio-stream = "0.1.14"
+tokio-stream = { version = "0.1.14", features = ["sync"] }
 tower-http = { version = "0.5.1", features = ["cors"] }
 tracing = "0.1.37"
 tracing-opentelemetry = "0.21.0"


@@ -6,20 +6,32 @@ use crate::queue::{Entry, Queue};
 use async_trait::async_trait;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
-use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::infer::{
+    Backend, EngineState, GeneratedText, InferError, InferStreamResponse,
+};
 use text_generation_router::validation::ValidGenerateRequest;
 use text_generation_router::{FinishReason, PrefillToken, Token};
+use tokio::sync::broadcast::{channel, Receiver as BroadcastReceiver, Sender as BroadcastSender};
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify};
+use tokio::sync::{mpsc, Notify, RwLock};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{info_span, instrument, Instrument, Span};

 pub struct BackendV3 {
+    /// Internal batching state exposing info for the proxy
+    state: Arc<RwLock<EngineState>>,
     /// Request queue
     queue: Queue,
+    /// Events streaming channel
+    state_events: (BroadcastSender<EngineState>, BroadcastReceiver<EngineState>),
     /// Notify batcher on queue appends
     batching_task_notifier: Arc<Notify>,
     /// Client clone, used for health checks to skip the queue
     client: ShardedClient,
 }
@@ -41,6 +53,12 @@ impl BackendV3 {
         let block_size = shard_info.block_size;

+        let state_events = channel(1);
+        let state = Arc::new(RwLock::new(EngineState::new(
+            max_batch_total_tokens,
+            2 * max_batch_total_tokens,
+        )));
+
         let queue = Queue::new(
             shard_info.requires_padding,
             block_size,
@@ -49,6 +67,8 @@ impl BackendV3 {
             shard_info.speculate,
             max_batch_total_tokens,
             shard_info.support_chunking,
+            Arc::clone(&state),
+            state_events.0.clone(),
         );

         let batching_task_notifier = Arc::new(Notify::new());
@@ -62,10 +82,14 @@ impl BackendV3 {
             max_batch_size,
             shard_info.support_chunking,
             queue.clone(),
+            state.clone(),
+            state_events.0.clone(),
             batching_task_notifier.clone(),
         ));

         Self {
+            state,
+            state_events,
             queue,
             batching_task_notifier,
             client,
@@ -112,6 +136,10 @@ impl Backend for BackendV3 {
             .is_ok()
     }

+    fn events(&self) -> BroadcastReceiver<EngineState> {
+        self.state_events.0.subscribe()
+    }
+
     fn start_health(&self) -> bool {
         true
     }
@@ -135,6 +163,8 @@ pub(crate) async fn batching_task(
     max_batch_size: Option<usize>,
     support_chunking: bool,
     queue: Queue,
+    engine_state: Arc<RwLock<EngineState>>,
+    batch_events: BroadcastSender<EngineState>,
     notifier: Arc<Notify>,
 ) {
     // Infinite loop
@@ -170,6 +200,22 @@ pub(crate) async fn batching_task(
             metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
             metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);

+            // Dispatch new state to the proxy
+            {
+                // Critical section, doing as little as possible here
+                {
+                    let mut engine_state = engine_state.write().await;
+                    engine_state.in_flight = batch_max_tokens;
+                }
+
+                // Send new state to the channel for broadcasting
+                if let Err(err) = batch_events.send(*engine_state.read().await) {
+                    tracing::warn!(
+                        "Failed to send BatchEvent::BatchChanged({batch_max_tokens}): {err}"
+                    )
+                }
+            }
+
             let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);

             let (min_size, max_size, prefill_token_budget) = if support_chunking {

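Note: the `state_events` channel above is created with `channel(1)`, so a subscriber that falls behind only observes the most recent `EngineState` snapshot plus a lag notification, which is all a live usage feed needs. A minimal standalone sketch of that behaviour with tokio's broadcast channel (illustrative only, not part of the commit):

use tokio::sync::broadcast;
use tokio::sync::broadcast::error::RecvError;

#[tokio::main]
async fn main() {
    // Capacity 1, mirroring `channel(1)` in BackendV3::new: only the latest
    // snapshot is kept around for slow subscribers.
    let (tx, mut rx) = broadcast::channel::<u32>(1);

    // Several state updates are published before the subscriber polls.
    for in_flight in [128u32, 256, 512] {
        let _ = tx.send(in_flight);
    }

    // The subscriber first learns it lagged, then receives only the latest value.
    assert!(matches!(rx.recv().await, Err(RecvError::Lagged(_))));
    assert_eq!(rx.recv().await.unwrap(), 512);
}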

@@ -6,14 +6,17 @@ use crate::client::{
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::max;
 use std::collections::VecDeque;
-use text_generation_router::infer::InferError;
+use std::sync::Arc;
 use text_generation_router::infer::InferStreamResponse;
+use text_generation_router::infer::{EngineState, InferError};
 use text_generation_router::validation::{
     Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
     ValidStoppingParameters,
 };
-use tokio::sync::{mpsc, oneshot};
+use tokio::sync::broadcast::Sender as BroadcastSender;
+use tokio::sync::{mpsc, oneshot, RwLock};
 use tokio::time::Instant;
+use tracing::log::warn;
 use tracing::{info_span, instrument, Instrument, Span};

 /// Queue entry
@@ -51,6 +54,8 @@ impl Queue {
         speculate: u32,
         max_batch_total_tokens: u32,
         support_chunking: bool,
+        engine_state: Arc<RwLock<EngineState>>,
+        queue_events: BroadcastSender<EngineState>,
     ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -64,7 +69,9 @@ impl Queue {
             speculate,
             max_batch_total_tokens,
             support_chunking,
+            engine_state,
             queue_receiver,
+            queue_events,
         ));

         Self { queue_sender }
@@ -113,7 +120,7 @@ impl Queue {
     }
 }

-// Background task responsible of the queue state
+// Background task responsible for the queue state
 #[allow(clippy::too_many_arguments)]
 async fn queue_task(
     requires_padding: bool,
@@ -123,7 +130,9 @@ async fn queue_task(
     speculate: u32,
     max_batch_total_tokens: u32,
     support_chunking: bool,
+    engine_state: Arc<RwLock<EngineState>>,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
+    queue_events: BroadcastSender<EngineState>,
 ) {
     let mut state = State::new(
         requires_padding,
@@ -138,8 +147,29 @@ async fn queue_task(
     while let Some(cmd) = receiver.recv().await {
         match cmd {
             QueueCommand::Append(entry, span) => {
+                let entry_num_tokens = entry.request.input_length;
                 span.in_scope(|| state.append(*entry));
                 metrics::gauge!("tgi_queue_size").increment(1.0);
+                metrics::gauge!("tgi_queue_size_tokens").increment(entry_num_tokens);
+
+                // Dispatch new state to the proxy
+                {
+                    // Lock free operation (read)
+                    let num_queued_tokens = engine_state.read().await.in_queue;
+
+                    {
+                        // Critical section, doing as little as possible here
+                        let mut engine_state = engine_state.write().await;
+                        engine_state.in_queue = num_queued_tokens + entry_num_tokens;
+                    }
+
+                    // Send new state to the channel for broadcasting
+                    if let Err(err) = queue_events.send(*engine_state.read().await) {
+                        tracing::warn!(
+                            "Failed to send BatchEvent::QueueChanged({}): {err}",
+                            num_queued_tokens + entry_num_tokens
+                        )
+                    }
+                }
             }
             QueueCommand::NextBatch {
                 min_size,
@@ -154,7 +184,33 @@ async fn queue_task(
                 .instrument(span)
                 .await;
                 response_sender.send(next_batch).unwrap();
+
+                {
+                    let num_batch_tokens = state
+                        .entries
+                        .iter()
+                        .map(|(_, e)| e.request.input_length)
+                        .sum::<u32>();
+
                     metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
+                    metrics::gauge!("tgi_queue_size_tokens").set(num_batch_tokens as f64);
+
+                    // Dispatch new state to the proxy
+                    {
+                        // Critical section, doing as little as possible here
+                        {
+                            let mut engine_state = engine_state.write().await;
+                            engine_state.in_queue = num_batch_tokens;
+                        }
+
+                        // Send new state to the channel for broadcasting
+                        if let Err(err) = queue_events.send(*engine_state.read().await) {
+                            tracing::warn!(
+                                "Failed to send BatchEvent::QueueChanged({}): {err}",
+                                num_batch_tokens
+                            )
+                        }
+                    }
+                }
             }
         }
     }

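Note: both `queue_task` and `batching_task` follow the same pattern: take the write lock only long enough to update one field, then broadcast a copy of the snapshot obtained through a read lock. A self-contained sketch of that read-modify-publish sequence (the struct is a simplified stand-in for `EngineState`, values are made up):

use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};

// Simplified stand-in for text_generation_router::infer::EngineState.
#[derive(Debug, Clone, Copy)]
struct EngineState {
    in_flight: u32,
    in_queue: u32,
}

#[tokio::main]
async fn main() {
    let state = Arc::new(RwLock::new(EngineState { in_flight: 0, in_queue: 0 }));
    let (tx, mut rx) = broadcast::channel::<EngineState>(1);

    let entry_num_tokens = 42u32;

    // Critical section kept as small as possible: mutate under the write lock...
    {
        let mut guard = state.write().await;
        guard.in_queue += entry_num_tokens;
    }

    // ...then broadcast a copy of the updated snapshot taken through a read lock.
    if let Err(err) = tx.send(*state.read().await) {
        eprintln!("no active subscribers: {err}");
    }

    let snapshot = rx.recv().await.unwrap();
    println!("in_queue = {}, in_flight = {}", snapshot.in_queue, snapshot.in_flight);
}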

@@ -73,5 +73,6 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
 [features]
 default = ["ngrok"]
 ngrok = ["dep:ngrok"]
+engine-state = []
 google = []
 kserve = []


@@ -19,12 +19,43 @@ use serde::Serialize;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 use thiserror::Error;
+use tokio::sync::broadcast::Receiver;
 use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
 use tracing::instrument;

+/// Store real-time information about batch engine usage (expressed in tokens)
+#[cfg(feature = "engine-state")]
+#[derive(Debug, Copy, Clone, Serialize)]
+pub struct EngineState {
+    /// Number of tokens currently participating in current batch
+    pub in_flight: u32,
+
+    /// Maximum number of tokens which can participate in a batch
+    pub in_flight_max: u32,
+
+    /// Number of tokens currently waiting in the queue for future batching
+    pub in_queue: u32,
+
+    /// Maximum number of tokens which can wait in the queue for future batching
+    pub in_queue_max: u32,
+}
+
+#[cfg(feature = "engine-state")]
+impl EngineState {
+    #[inline]
+    pub fn new(in_flight_max: u32, in_queue_max: u32) -> Self {
+        EngineState {
+            in_flight: 0,
+            in_flight_max,
+            in_queue: 0,
+            in_queue_max,
+        }
+    }
+}
+
 #[async_trait]
 pub trait Backend {
     fn schedule(
@@ -34,6 +65,11 @@ pub trait Backend {
     async fn health(&self, current_health: bool) -> bool;

+    /// Gets a reference to receiving-side channel generating events about the current internal
+    /// batching engine state
+    #[cfg(feature = "engine-state")]
+    fn events(&self) -> Receiver<EngineState>;
+
     /// The state of the health on startup
     /// Typically false, or true if the backend includes
     /// a warmup phase.
@@ -95,6 +131,12 @@ impl Infer {
         }
     }

+    #[cfg(feature = "engine-state")]
+    #[inline]
+    pub(crate) fn events(&self) -> Receiver<EngineState> {
+        self.backend.events()
+    }
+
     /// Add a new request to the queue and return a stream of InferStreamResponse
     #[instrument(skip_all)]
     pub(crate) async fn generate_stream<'a>(

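Note: because `EngineState` derives `Serialize`, each event pushed on the new `/state` stream carries a small JSON object. A rough illustration of the payload shape (field values are invented; requires serde and serde_json):

use serde::Serialize;

// Mirrors the EngineState struct introduced above.
#[derive(Debug, Copy, Clone, Serialize)]
struct EngineState {
    in_flight: u32,
    in_flight_max: u32,
    in_queue: u32,
    in_queue_max: u32,
}

fn main() {
    let state = EngineState {
        in_flight: 2048,
        in_flight_max: 8192,
        in_queue: 128,
        in_queue_max: 16384,
    };

    // Prints: {"in_flight":2048,"in_flight_max":8192,"in_queue":128,"in_queue_max":16384}
    println!("{}", serde_json::to_string(&state).unwrap());
}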

@@ -61,11 +61,13 @@ use tokio::select;
 use tokio::signal;
 use tokio::sync::oneshot;
 use tokio::time::Instant;
+use tokio_stream::wrappers::BroadcastStream;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;

 fn encoding_to_tokens(encoding: &tokenizers::Encoding, input: &str) -> Vec<SimpleToken> {
     let offsets = encoding.get_offsets();
     let input_ids = encoding.get_ids();
@@ -1309,6 +1311,28 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     prom_handle.render()
 }

+#[utoipa::path(get, tag = "Text Generation Inference", path = "/state")]
+#[instrument(skip_all)]
+async fn state(
+    Extension(infer): Extension<Infer>,
+) -> Result<Sse<impl Stream<Item = Result<Event, axum::Error>>>, StatusCode> {
+    if cfg!(feature = "engine-state") {
+        let stream = infer.events();
+        let sse =
+            Sse::new(BroadcastStream::from(stream).map(|state| {
+                Event::default().json_data(state.map_err(|err| axum::Error::new(err))?)
+            }))
+            .keep_alive(
+                KeepAlive::new()
+                    .interval(Duration::from_secs(5))
+                    .text("more_open_models_on_hf"),
+            );
+
+        Ok(sse)
+    } else {
+        Err(StatusCode::NOT_IMPLEMENTED)
+    }
+}
+
 #[derive(Clone, Debug)]
 pub(crate) struct ComputeType(String);
@@ -1328,6 +1352,7 @@ metrics,
     openai_get_model_info,
     sagemaker_compatibility,
     get_chat_tokenize,
+    state,
 ),
 components(
     schemas(
@@ -2007,6 +2032,11 @@ async fn start(
         "Current batch size"
     );
     metrics::describe_gauge!("tgi_queue_size", metrics::Unit::Count, "Current queue size");
+    metrics::describe_gauge!(
+        "tgi_queue_size_tokens",
+        metrics::Unit::Count,
+        "Current queue size in number of tokens"
+    );
     metrics::describe_gauge!(
         "tgi_batch_current_max_tokens",
         metrics::Unit::Count,
@@ -2203,6 +2233,7 @@ async fn start(
         .route("/health", get(health))
         .route("/ping", get(health))
         .route("/metrics", get(metrics))
+        .route("/state", get(state))
         .route("/v1/models", get(openai_get_model_info));

     let compute_type =
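
Note: the new `/state` route exposes a server-sent events stream. A hedged client sketch that reads the stream and prints each `data:` payload; it assumes the router listens on localhost:3000 and uses reqwest built with its `stream` feature plus futures-util (a robust client would also buffer frames split across chunks):

use futures_util::StreamExt;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical address; adjust to your deployment.
    let response = reqwest::get("http://localhost:3000/state").await?;
    let mut body = response.bytes_stream();

    while let Some(chunk) = body.next().await {
        let chunk = chunk?;
        // SSE frames look like: data: {"in_flight":...,"in_queue":...}
        // Keep-alive comments arrive as lines starting with ':'.
        for line in String::from_utf8_lossy(&chunk).lines() {
            if let Some(json) = line.strip_prefix("data: ") {
                println!("engine state update: {json}");
            }
        }
    }
    Ok(())
}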