mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-26 12:32:10 +00:00
feat(metrics): dispatch internal engine state event from queuing/batching tasks
This commit is contained in:
parent 1a9c5dec76
commit 712199c769
@@ -11,22 +11,27 @@ use text_generation_router::infer::{
 };
 use text_generation_router::validation::ValidGenerateRequest;
 use text_generation_router::{FinishReason, PrefillToken, Token};
-use tokio::sync::broadcast::{channel, Receiver, Sender};
+use tokio::sync::broadcast::{channel, Receiver as BroadcastReceiver, Sender as BroadcastSender};
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify};
+use tokio::sync::{mpsc, Notify, RwLock};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{info_span, instrument, Instrument, Span};

 pub struct BackendV3 {
-    /// Events streaming channel
-    events: (Sender<EngineState>, Receiver<EngineState>),
+    /// Internal batching state exposing info for the proxy
+    state: Arc<RwLock<EngineState>>,

     /// Request queue
     queue: Queue,

+    /// Events streaming channel
+    state_events: (BroadcastSender<EngineState>, BroadcastReceiver<EngineState>),
+
     /// Notify batcher on queue appends
     batching_task_notifier: Arc<Notify>,

     /// Client clone, used for health checks to skip the queue
     client: ShardedClient,
 }
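For orientation, any holder of a BackendV3 can watch engine state through this broadcast pair. A minimal consumer sketch, assuming a `backend: BackendV3` in scope and the `events()` accessor changed further down in this diff (the spawned task and logging are illustrative, not part of the commit):

    // Subscribe to EngineState snapshots broadcast by the queue/batching tasks.
    let mut rx = backend.events(); // BroadcastReceiver<EngineState>
    tokio::spawn(async move {
        // Ok(state) is a full snapshot; Err covers lagging behind the
        // single-slot channel or all senders being dropped.
        while let Ok(state) = rx.recv().await {
            tracing::debug!(?state, "engine state changed");
        }
    });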
@@ -48,6 +53,12 @@ impl BackendV3 {
         let block_size = shard_info.block_size;

+        let state_events = channel(1);
+        let state = Arc::new(RwLock::new(EngineState::new(
+            max_batch_total_tokens,
+            2 * max_batch_total_tokens,
+        )));
+
         let queue = Queue::new(
             shard_info.requires_padding,
             block_size,
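The broadcast channel is created with capacity 1, so only the latest EngineState snapshot is retained: a slow subscriber sees a Lagged error and then the newest value instead of a backlog, which is harmless because every event carries the full state. A self-contained sketch of that tokio semantics (values illustrative):

    use tokio::sync::broadcast;

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = broadcast::channel::<u32>(1);
        tx.send(1).unwrap();
        tx.send(2).unwrap(); // evicts 1 from the single-slot buffer
        // The receiver first learns it missed one message...
        assert!(matches!(rx.recv().await, Err(broadcast::error::RecvError::Lagged(1))));
        // ...then receives the most recent value.
        assert_eq!(rx.recv().await.unwrap(), 2);
    }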
@@ -56,6 +67,8 @@ impl BackendV3 {
             shard_info.speculate,
             max_batch_total_tokens,
             shard_info.support_chunking,
+            Arc::clone(&state),
+            state_events.0.clone(),
         );
         let batching_task_notifier = Arc::new(Notify::new());

@@ -69,11 +82,14 @@ impl BackendV3 {
             max_batch_size,
             shard_info.support_chunking,
             queue.clone(),
+            state.clone(),
+            state_events.0.clone(),
             batching_task_notifier.clone(),
         ));

         Self {
-            events: channel(1),
+            state,
+            state_events,
             queue,
             batching_task_notifier,
             client,
@@ -120,8 +136,8 @@ impl Backend for BackendV3 {
         .is_ok()
     }

-    fn events(&self) -> Receiver<EngineState> {
-        self.events.0.subscribe()
+    fn events(&self) -> BroadcastReceiver<EngineState> {
+        self.state_events.0.subscribe()
     }

     fn start_health(&self) -> bool {
@@ -147,6 +163,8 @@ pub(crate) async fn batching_task(
     max_batch_size: Option<usize>,
     support_chunking: bool,
     queue: Queue,
+    engine_state: Arc<RwLock<EngineState>>,
+    batch_events: BroadcastSender<EngineState>,
     notifier: Arc<Notify>,
 ) {
     // Infinite loop
@@ -182,6 +200,22 @@ pub(crate) async fn batching_task(
             metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
             metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);

+            // Dispatch new state to the proxy
+            {
+                // Critical section, doing as little as possible here
+                {
+                    let mut engine_state = engine_state.write().await;
+                    engine_state.in_flight = batch_max_tokens;
+                }
+
+                // Send new state to the channel for broadcasting
+                if let Err(err) = batch_events.send(*engine_state.read().await) {
+                    tracing::warn!(
+                        "Failed to send BatchEvent::BatchChanged({batch_max_tokens}): {err}"
+                    )
+                }
+            }
+
             let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);

             let (min_size, max_size, prefill_token_budget) = if support_chunking {
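The write lock above is dropped before the snapshot is re-read for sending, so under concurrent writers the broadcast value could reflect an even newer state; that is benign, since each event is a complete snapshot. Because EngineState is Copy, a single-lock variant is also possible, sketched here as an alternative rather than what the commit does:

    // Update under one write lock and copy the snapshot out for sending.
    let snapshot = {
        let mut engine_state = engine_state.write().await;
        engine_state.in_flight = batch_max_tokens;
        *engine_state // EngineState is Copy, so the guard can be dropped here
    };
    if let Err(err) = batch_events.send(snapshot) {
        tracing::warn!("Failed to send BatchEvent::BatchChanged({batch_max_tokens}): {err}");
    }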
@@ -6,14 +6,17 @@ use crate::client::{
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::max;
 use std::collections::VecDeque;
-use text_generation_router::infer::InferError;
+use std::sync::Arc;
 use text_generation_router::infer::InferStreamResponse;
+use text_generation_router::infer::{EngineState, InferError};
 use text_generation_router::validation::{
     Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
     ValidStoppingParameters,
 };
-use tokio::sync::{mpsc, oneshot};
+use tokio::sync::broadcast::Sender as BroadcastSender;
+use tokio::sync::{mpsc, oneshot, RwLock};
 use tokio::time::Instant;
 use tracing::log::warn;
 use tracing::{info_span, instrument, Instrument, Span};

 /// Queue entry
@@ -51,6 +54,8 @@ impl Queue {
         speculate: u32,
         max_batch_total_tokens: u32,
         support_chunking: bool,
+        engine_state: Arc<RwLock<EngineState>>,
+        queue_events: BroadcastSender<EngineState>,
     ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -64,7 +69,9 @@ impl Queue {
             speculate,
             max_batch_total_tokens,
             support_chunking,
+            engine_state,
             queue_receiver,
+            queue_events,
         ));

         Self { queue_sender }
@@ -113,7 +120,7 @@ impl Queue {
     }
 }

-// Background task responsible of the queue state
+// Background task responsible for the queue state
 #[allow(clippy::too_many_arguments)]
 async fn queue_task(
     requires_padding: bool,
@@ -123,7 +130,9 @@ async fn queue_task(
     speculate: u32,
     max_batch_total_tokens: u32,
     support_chunking: bool,
+    engine_state: Arc<RwLock<EngineState>>,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
+    queue_events: BroadcastSender<EngineState>,
 ) {
     let mut state = State::new(
         requires_padding,
@@ -138,9 +147,29 @@ async fn queue_task(
     while let Some(cmd) = receiver.recv().await {
         match cmd {
             QueueCommand::Append(entry, span) => {
-                metrics::gauge!("tgi_queue_size").increment(1.0);
-                metrics::gauge!("tgi_queue_size_tokens").increment(entry.request.input_length);
+                let entry_num_tokens = entry.request.input_length;
                 span.in_scope(|| state.append(*entry));
+                metrics::gauge!("tgi_queue_size").increment(1.0);
+                metrics::gauge!("tgi_queue_size_tokens").increment(entry_num_tokens);
+
+                // Dispatch new state to the proxy
+                {
+                    // Lock free operation (read)
+                    let num_queued_tokens = engine_state.read().await.in_queue;
+                    {
+                        // Critical section, doing as little as possible here
+                        let mut engine_state = engine_state.write().await;
+                        engine_state.in_queue = num_queued_tokens + entry_num_tokens;
+                    }
+
+                    // Send new state to the channel for broadcasting
+                    if let Err(err) = queue_events.send(*engine_state.read().await) {
+                        tracing::warn!(
+                            "Failed to send BatchEvent::QueueChanged({}): {err}",
+                            num_queued_tokens + entry_num_tokens
+                        )
+                    }
+                }
             }
             QueueCommand::NextBatch {
                 min_size,
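Note that `in_queue` is read under one read lock and written back under a separate write lock; this stays race-free only while queue_task is the sole writer of that field (batching_task above touches only `in_flight`). A compact equivalent under a single write lock, shown purely for illustration:

    {
        // One critical section: read-modify-write of the queued-token count.
        let mut engine_state = engine_state.write().await;
        engine_state.in_queue += entry_num_tokens;
    }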
@@ -156,14 +185,32 @@ async fn queue_task(
                     .await;
                 response_sender.send(next_batch).unwrap();

-                metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
-                metrics::gauge!("tgi_queue_size_tokens").set(
-                    state
-                        .entries
-                        .iter()
-                        .map(|(_, e)| e.request.input_length as f64)
-                        .sum::<f64>(),
-                );
+                {
+                    let num_batch_tokens = state
+                        .entries
+                        .iter()
+                        .map(|(_, e)| e.request.input_length)
+                        .sum::<u32>();
+                    metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
+                    metrics::gauge!("tgi_queue_size_tokens").set(num_batch_tokens as f64);
+
+                    // Dispatch new state to the proxy
+                    {
+                        // Critical section, doing as little as possible here
+                        {
+                            let mut engine_state = engine_state.write().await;
+                            engine_state.in_queue = num_batch_tokens;
+                        }
+
+                        // Send new state to the channel for broadcasting
+                        if let Err(err) = queue_events.send(*engine_state.read().await) {
+                            tracing::warn!(
+                                "Failed to send BatchEvent::QueueChanged({}): {err}",
+                                num_batch_tokens
+                            )
+                        }
+                    }
+                }
             }
         }
     }
@@ -31,16 +31,29 @@ use tracing::instrument;
 #[derive(Debug, Copy, Clone, Serialize)]
 pub struct EngineState {
     /// Number of tokens currently participating in current batch
-    in_flight: u32,
+    pub in_flight: u32,

     /// Maximum number of tokens which can participate in a batch
-    in_flight_max: u32,
+    pub in_flight_max: u32,

     /// Number of tokens currently waiting in the queue for future batching
-    in_queue: u32,
+    pub in_queue: u32,

     /// Maximum number of tokens which can wait in the queue for future batching
-    in_queue_max: u32,
+    pub in_queue_max: u32,
 }

+#[cfg(feature = "engine-state")]
+impl EngineState {
+    #[inline]
+    pub fn new(in_flight_max: u32, in_queue_max: u32) -> Self {
+        EngineState {
+            in_flight: 0,
+            in_flight_max,
+            in_queue: 0,
+            in_queue_max,
+        }
+    }
+}
+
 #[async_trait]
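Since EngineState derives Serialize, a snapshot renders directly to JSON. A minimal sketch, assuming the `engine-state` feature gating `EngineState::new` is enabled and serde_json is available (values illustrative):

    let state = EngineState::new(16_384, 32_768);
    // Prints: {"in_flight":0,"in_flight_max":16384,"in_queue":0,"in_queue_max":32768}
    println!("{}", serde_json::to_string(&state).unwrap());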
@@ -62,7 +62,6 @@ use tokio::select;
 use tokio::signal;
 use tokio::sync::oneshot;
 use tokio::time::Instant;
-use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
 use tokio_stream::wrappers::BroadcastStream;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
@@ -2398,6 +2397,7 @@ async fn start(
         .route("/health", get(health))
         .route("/ping", get(health))
         .route("/metrics", get(metrics))
+        .route("/state", get(state))
         .route("/v1/models", get(openai_get_model_info));

     let compute_type =
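The `state` handler wired to this new route is not part of the excerpt above; given the BroadcastStream import kept earlier, it may well stream events, but even a plain JSON endpoint would only need the shared state. A hypothetical axum-style sketch (handler shape, extension layer, and naming are all assumptions, not the committed code):

    use axum::{extract::Extension, Json};

    // Hypothetical: return the current EngineState snapshot as JSON.
    async fn state(
        Extension(engine_state): Extension<Arc<RwLock<EngineState>>>,
    ) -> Json<EngineState> {
        Json(*engine_state.read().await)
    }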