mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00

commit 6569c64217: Merge f72547c9fb into 4645678ff0
Cargo.lock (generated)

@@ -5134,6 +5134,7 @@ dependencies = [
  "futures-core",
  "pin-project-lite",
  "tokio",
+ "tokio-util",
 ]

 [[package]]
backends/v3/Cargo.toml

@@ -18,7 +18,7 @@ async-trait = "0.1.74"
 async-stream = "0.3.5"
 axum = { version = "0.7", features = ["json"] }
 axum-tracing-opentelemetry = "0.16"
-text-generation-router = { path = "../../router" }
+text-generation-router = { path = "../../router", features = ["engine-state"] }
 clap = { version = "4.4.5", features = ["derive", "env"] }
 grpc-metadata = { path = "../grpc-metadata" }
 futures = "0.3.28"
@@ -37,13 +37,13 @@ slotmap = "1.0.7"
 thiserror = "1.0.48"
 tokenizers = { workspace = true }
 tokio = { version = "1.32.0", features = [
   "rt",
   "rt-multi-thread",
   "parking_lot",
   "signal",
   "sync",
 ] }
-tokio-stream = "0.1.14"
+tokio-stream = { version = "0.1.14", features = ["sync"] }
 tower-http = { version = "0.5.1", features = ["cors"] }
 tracing = "0.1.37"
 tracing-opentelemetry = "0.21.0"
@@ -51,7 +51,7 @@ tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
 utoipa = { version = "4.2.0", features = ["axum_extras"] }
 utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
 init-tracing-opentelemetry = { version = "0.14.1", features = [
   "opentelemetry-otlp",
 ] }
 minijinja = { workspace = true }
 minijinja-contrib = { workspace = true }
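The `sync` feature on tokio-stream is what pulls `tokio-util` into the Cargo.lock hunk above, and it enables `BroadcastStream`, the adapter the server uses later to expose a `broadcast::Receiver` as a `Stream`. A minimal standalone sketch of that adapter, independent of this PR's types:

```rust
use tokio::sync::broadcast;
use tokio_stream::wrappers::BroadcastStream;
use tokio_stream::StreamExt;

#[tokio::main]
async fn main() {
    // Requires tokio-stream with the "sync" feature (which depends on tokio-util).
    let (tx, rx) = broadcast::channel::<u32>(16);
    let mut stream = BroadcastStream::new(rx);

    tx.send(1).unwrap();
    tx.send(2).unwrap();
    drop(tx); // closing the sender ends the stream once buffered items drain

    // Items arrive as Result<T, BroadcastStreamRecvError> because a slow
    // consumer can lag behind the channel's buffer and miss values.
    while let Some(item) = stream.next().await {
        println!("{:?}", item);
    }
}
```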
backends/v3/src/backend.rs

@@ -6,20 +6,32 @@ use crate::queue::{Entry, Queue};
 use async_trait::async_trait;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
-use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::infer::{
+    Backend, EngineState, GeneratedText, InferError, InferStreamResponse,
+};
 use text_generation_router::validation::ValidGenerateRequest;
 use text_generation_router::{FinishReason, PrefillToken, Token};
+use tokio::sync::broadcast::{channel, Receiver as BroadcastReceiver, Sender as BroadcastSender};
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify};
+use tokio::sync::{mpsc, Notify, RwLock};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{info_span, instrument, Instrument, Span};

 pub struct BackendV3 {
+    /// Internal batching state exposing info for the proxy
+    state: Arc<RwLock<EngineState>>,
+
     /// Request queue
     queue: Queue,

+    /// Events streaming channel
+    state_events: (BroadcastSender<EngineState>, BroadcastReceiver<EngineState>),
+
     /// Notify batcher on queue appends
     batching_task_notifier: Arc<Notify>,

     /// Client clone, used for health checks to skip the queue
     client: ShardedClient,
 }
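Storing the `(Sender, Receiver)` pair on the struct keeps at least one receiver alive, so `send` cannot fail with a closed channel even before any client subscribes; fresh receivers are handed out via `Sender::subscribe`, as `events()` does below. A reduced sketch of that pattern (the names here are illustrative, not from the PR):

```rust
use tokio::sync::broadcast::{channel, Receiver, Sender};

struct EventBus {
    // Keeping the original Receiver alive means send() always has at least
    // one subscriber and never returns SendError while the bus exists.
    events: (Sender<u32>, Receiver<u32>),
}

impl EventBus {
    fn new() -> Self {
        // Capacity 1: only the latest value matters; laggards skip ahead.
        Self { events: channel(1) }
    }

    fn subscribe(&self) -> Receiver<u32> {
        // Each subscriber gets an independent cursor into the channel.
        self.events.0.subscribe()
    }

    fn publish(&self, value: u32) {
        let _ = self.events.0.send(value);
    }
}

fn main() {
    let bus = EventBus::new();
    let mut rx = bus.subscribe();
    bus.publish(42);
    assert_eq!(rx.try_recv().unwrap(), 42);
}
```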
@@ -41,6 +53,12 @@ impl BackendV3 {

         let block_size = shard_info.block_size;

+        let state_events = channel(1);
+        let state = Arc::new(RwLock::new(EngineState::new(
+            max_batch_total_tokens,
+            2 * max_batch_total_tokens,
+        )));
+
         let queue = Queue::new(
             shard_info.requires_padding,
             block_size,
@@ -49,6 +67,8 @@ impl BackendV3 {
             shard_info.speculate,
             max_batch_total_tokens,
             shard_info.support_chunking,
+            Arc::clone(&state),
+            state_events.0.clone(),
         );
         let batching_task_notifier = Arc::new(Notify::new());

@@ -62,10 +82,14 @@ impl BackendV3 {
             max_batch_size,
             shard_info.support_chunking,
             queue.clone(),
+            state.clone(),
+            state_events.0.clone(),
             batching_task_notifier.clone(),
         ));

         Self {
+            state,
+            state_events,
             queue,
             batching_task_notifier,
             client,
@@ -112,6 +136,10 @@ impl Backend for BackendV3 {
             .is_ok()
     }

+    fn events(&self) -> BroadcastReceiver<EngineState> {
+        self.state_events.0.subscribe()
+    }
+
     fn start_health(&self) -> bool {
         true
     }
@@ -135,6 +163,8 @@ pub(crate) async fn batching_task(
     max_batch_size: Option<usize>,
     support_chunking: bool,
     queue: Queue,
+    engine_state: Arc<RwLock<EngineState>>,
+    batch_events: BroadcastSender<EngineState>,
     notifier: Arc<Notify>,
 ) {
     // Infinite loop
@@ -170,6 +200,22 @@ pub(crate) async fn batching_task(
         metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
         metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);

+        // Dispatch new state to the proxy
+        {
+            // Critical section, doing as little as possible here
+            {
+                let mut engine_state = engine_state.write().await;
+                engine_state.in_flight = batch_max_tokens;
+            }
+
+            // Send new state to the channel for broadcasting
+            if let Err(err) = batch_events.send(*engine_state.read().await) {
+                tracing::warn!(
+                    "Failed to send BatchEvent::BatchChanged({batch_max_tokens}): {err}"
+                )
+            }
+        }
+
         let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);

         let (min_size, max_size, prefill_token_budget) = if support_chunking {
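The dispatch block above follows a deliberate locking pattern: mutate under a short-lived write guard, let it drop, then take a read guard only to copy the `Copy` state out for broadcasting. A minimal sketch of the same pattern outside this PR's types (names here are assumptions for illustration):

```rust
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};

#[derive(Debug, Copy, Clone)]
struct State {
    in_flight: u32,
}

async fn publish_in_flight(
    state: &Arc<RwLock<State>>,
    events: &broadcast::Sender<State>,
    batch_max_tokens: u32,
) {
    {
        // Critical section: hold the write guard only for the field update.
        let mut guard = state.write().await;
        guard.in_flight = batch_max_tokens;
    } // write guard dropped here, before anyone else needs the lock

    // State is Copy, so dereferencing the read guard cheaply snapshots it.
    if let Err(err) = events.send(*state.read().await) {
        // Only fails when no receiver is alive; log and move on.
        eprintln!("no subscribers for state update: {err}");
    }
}

#[tokio::main]
async fn main() {
    let state = Arc::new(RwLock::new(State { in_flight: 0 }));
    let (tx, mut rx) = broadcast::channel(1);
    publish_in_flight(&state, &tx, 512).await;
    assert_eq!(rx.recv().await.unwrap().in_flight, 512);
}
```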
backends/v3/src/queue.rs

@@ -6,14 +6,17 @@ use crate::client::{
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::max;
 use std::collections::VecDeque;
-use text_generation_router::infer::InferError;
+use std::sync::Arc;
 use text_generation_router::infer::InferStreamResponse;
+use text_generation_router::infer::{EngineState, InferError};
 use text_generation_router::validation::{
     Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
     ValidStoppingParameters,
 };
-use tokio::sync::{mpsc, oneshot};
+use tokio::sync::broadcast::Sender as BroadcastSender;
+use tokio::sync::{mpsc, oneshot, RwLock};
 use tokio::time::Instant;
+use tracing::log::warn;
 use tracing::{info_span, instrument, Instrument, Span};

 /// Queue entry
@@ -51,6 +54,8 @@ impl Queue {
         speculate: u32,
         max_batch_total_tokens: u32,
         support_chunking: bool,
+        engine_state: Arc<RwLock<EngineState>>,
+        queue_events: BroadcastSender<EngineState>,
     ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -64,7 +69,9 @@ impl Queue {
             speculate,
             max_batch_total_tokens,
             support_chunking,
+            engine_state,
             queue_receiver,
+            queue_events,
         ));

         Self { queue_sender }
@@ -113,7 +120,7 @@ impl Queue {
     }
 }

-// Background task responsible of the queue state
+// Background task responsible for the queue state
 #[allow(clippy::too_many_arguments)]
 async fn queue_task(
     requires_padding: bool,
@@ -123,7 +130,9 @@ async fn queue_task(
     speculate: u32,
     max_batch_total_tokens: u32,
     support_chunking: bool,
+    engine_state: Arc<RwLock<EngineState>>,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
+    queue_events: BroadcastSender<EngineState>,
 ) {
     let mut state = State::new(
         requires_padding,
@@ -138,8 +147,29 @@ async fn queue_task(
     while let Some(cmd) = receiver.recv().await {
         match cmd {
             QueueCommand::Append(entry, span) => {
+                let entry_num_tokens = entry.request.input_length;
                 span.in_scope(|| state.append(*entry));
                 metrics::gauge!("tgi_queue_size").increment(1.0);
+                metrics::gauge!("tgi_queue_size_tokens").increment(entry_num_tokens);
+
+                // Dispatch new state to the proxy
+                {
+                    // Lock free operation (read)
+                    let num_queued_tokens = engine_state.read().await.in_queue;
+                    {
+                        // Critical section, doing as little as possible here
+                        let mut engine_state = engine_state.write().await;
+                        engine_state.in_queue = num_queued_tokens + entry_num_tokens;
+                    }
+
+                    // Send new state to the channel for broadcasting
+                    if let Err(err) = queue_events.send(*engine_state.read().await) {
+                        tracing::warn!(
+                            "Failed to send BatchEvent::QueueChanged({}): {err}",
+                            num_queued_tokens + entry_num_tokens
+                        )
+                    }
+                }
             }
             QueueCommand::NextBatch {
                 min_size,
@@ -154,7 +184,33 @@ async fn queue_task(
                 .instrument(span)
                 .await;
                 response_sender.send(next_batch).unwrap();
-                metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
+
+                {
+                    let num_batch_tokens = state
+                        .entries
+                        .iter()
+                        .map(|(_, e)| e.request.input_length)
+                        .sum::<u32>();
+                    metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
+                    metrics::gauge!("tgi_queue_size_tokens").set(num_batch_tokens as f64);
+
+                    // Dispatch new state to the proxy
+                    {
+                        // Critical section, doing as little as possible here
+                        {
+                            let mut engine_state = engine_state.write().await;
+                            engine_state.in_queue = num_batch_tokens;
+                        }
+
+                        // Send new state to the channel for broadcasting
+                        if let Err(err) = queue_events.send(*engine_state.read().await) {
+                            tracing::warn!(
+                                "Failed to send BatchEvent::QueueChanged({}): {err}",
+                                num_batch_tokens
+                            )
+                        }
+                    }
+                }
             }
         }
     }
 }
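The Append arm reads `in_queue` under one guard and writes the sum under a second one. Since only `queue_task` ever writes `in_queue`, that is safe here, but the same update can also be done atomically under a single write guard, which avoids the question entirely. A minimal sketch of that variant (illustrative only, not the PR's code):

```rust
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};

#[derive(Debug, Copy, Clone)]
struct EngineState {
    in_queue: u32,
}

async fn append_tokens(
    state: &Arc<RwLock<EngineState>>,
    events: &broadcast::Sender<EngineState>,
    entry_num_tokens: u32,
) {
    // Read-modify-write under one guard: no other writer can slip in
    // between reading the old value and storing the new one.
    let snapshot = {
        let mut guard = state.write().await;
        guard.in_queue += entry_num_tokens;
        *guard // EngineState is Copy, so snapshot it while still locked
    };

    if let Err(err) = events.send(snapshot) {
        eprintln!("no subscribers for queue update: {err}");
    }
}

#[tokio::main]
async fn main() {
    let state = Arc::new(RwLock::new(EngineState { in_queue: 0 }));
    let (tx, mut rx) = broadcast::channel(1);
    append_tokens(&state, &tx, 128).await;
    assert_eq!(rx.recv().await.unwrap().in_queue, 128);
}
```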
router/Cargo.toml

@@ -31,11 +31,11 @@ serde_json = "1.0.107"
 thiserror = "1.0.48"
 tokenizers = { workspace = true }
 tokio = { version = "1.32.0", features = [
   "rt",
   "rt-multi-thread",
   "parking_lot",
   "signal",
   "sync",
 ] }
 tokio-stream = "0.1.14"
 tower-http = { version = "0.5.1", features = ["cors"] }
@@ -46,7 +46,7 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
 utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
 ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 init-tracing-opentelemetry = { version = "0.14.1", features = [
   "opentelemetry-otlp",
 ] }
 minijinja = { workspace = true, features = ["loop_controls"] }
 minijinja-contrib = { workspace = true }
@@ -57,9 +57,9 @@ image = "0.25.1"
 base64 = { workspace = true }
 sysinfo = "0.30.13"
 uuid = { version = "1.9.1", default-features = false, features = [
   "v4",
   "fast-rng",
   "macro-diagnostics",
 ] }
 csv = "1.3.0"
 ureq = "=2.9"
@@ -73,5 +73,6 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
 [features]
 default = ["ngrok"]
 ngrok = ["dep:ngrok"]
+engine-state = []
 google = []
 kserve = []
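The new `engine-state` feature is off by default; downstream crates opt in explicitly, as the v3 backend does in its Cargo.toml hunk earlier, or via `cargo build --features engine-state` when building the router directly. A consumer's dependency line would look like this (mirroring the in-repo layout shown above):

```toml
[dependencies]
text-generation-router = { path = "../../router", features = ["engine-state"] }
```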
router/src/infer/mod.rs

@@ -19,12 +19,43 @@ use serde::Serialize;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 use thiserror::Error;
+use tokio::sync::broadcast::Receiver;
 use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
 use tracing::instrument;

+/// Store real-time information about batch engine usage (expressed in tokens)
+#[cfg(feature = "engine-state")]
+#[derive(Debug, Copy, Clone, Serialize)]
+pub struct EngineState {
+    /// Number of tokens currently participating in the current batch
+    pub in_flight: u32,
+
+    /// Maximum number of tokens which can participate in a batch
+    pub in_flight_max: u32,
+
+    /// Number of tokens currently waiting in the queue for future batching
+    pub in_queue: u32,
+
+    /// Maximum number of tokens which can wait in the queue for future batching
+    pub in_queue_max: u32,
+}
+
+#[cfg(feature = "engine-state")]
+impl EngineState {
+    #[inline]
+    pub fn new(in_flight_max: u32, in_queue_max: u32) -> Self {
+        EngineState {
+            in_flight: 0,
+            in_flight_max,
+            in_queue: 0,
+            in_queue_max,
+        }
+    }
+}
+
 #[async_trait]
 pub trait Backend {
     fn schedule(
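Since `EngineState` derives `Serialize` and is what the `/state` endpoint emits, each SSE frame carries a small JSON object whose field order follows the struct declaration. A sketch of the expected payload, assuming serde_json is available and a hypothetical `max_batch_total_tokens` of 16384 (recall the constructor is called with twice that value for the queue bound):

```rust
use serde::Serialize;

#[derive(Debug, Copy, Clone, Serialize)]
pub struct EngineState {
    pub in_flight: u32,
    pub in_flight_max: u32,
    pub in_queue: u32,
    pub in_queue_max: u32,
}

fn main() {
    // Mirrors EngineState::new(max_batch_total_tokens, 2 * max_batch_total_tokens)
    // right after startup, before any request has been queued or batched.
    let state = EngineState {
        in_flight: 0,
        in_flight_max: 16384,
        in_queue: 0,
        in_queue_max: 32768,
    };
    assert_eq!(
        serde_json::to_string(&state).unwrap(),
        r#"{"in_flight":0,"in_flight_max":16384,"in_queue":0,"in_queue_max":32768}"#
    );
}
```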
@@ -34,6 +65,11 @@ pub trait Backend {

     async fn health(&self, current_health: bool) -> bool;

+    /// Gets a receiving-side channel of events describing the current internal
+    /// batching engine state
+    #[cfg(feature = "engine-state")]
+    fn events(&self) -> Receiver<EngineState>;
+
     /// The state of the health on startup
     /// Typically false, or true if the backend includes
     /// a warmup phase.
@@ -95,6 +131,12 @@ impl Infer {
         }
     }

+    #[cfg(feature = "engine-state")]
+    #[inline]
+    pub(crate) fn events(&self) -> Receiver<EngineState> {
+        self.backend.events()
+    }
+
     /// Add a new request to the queue and return a stream of InferStreamResponse
     #[instrument(skip_all)]
     pub(crate) async fn generate_stream<'a>(
router/src/server.rs

@@ -61,11 +61,13 @@ use tokio::select;
 use tokio::signal;
 use tokio::sync::oneshot;
 use tokio::time::Instant;
+use tokio_stream::wrappers::BroadcastStream;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;


 fn encoding_to_tokens(encoding: &tokenizers::Encoding, input: &str) -> Vec<SimpleToken> {
     let offsets = encoding.get_offsets();
     let input_ids = encoding.get_ids();
@@ -1309,6 +1311,28 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     prom_handle.render()
 }

+#[utoipa::path(get, tag = "Text Generation Inference", path = "/state")]
+#[instrument(skip_all)]
+async fn state(
+    Extension(infer): Extension<Infer>,
+) -> Result<Sse<impl Stream<Item = Result<Event, axum::Error>>>, StatusCode> {
+    if cfg!(feature = "engine-state") {
+        let stream = infer.events();
+        let sse = Sse::new(BroadcastStream::from(stream).map(|state| {
+            Event::default().json_data(state.map_err(|err| axum::Error::new(err))?)
+        }))
+        .keep_alive(
+            KeepAlive::new()
+                .interval(Duration::from_secs(5))
+                .text("more_open_models_on_hf"),
+        );
+        Ok(sse)
+    } else {
+        Err(StatusCode::NOT_IMPLEMENTED)
+    }
+}
+
 #[derive(Clone, Debug)]
 pub(crate) struct ComputeType(String);

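Any SSE consumer can watch the new endpoint, e.g. `curl -N http://<host>:<port>/state`. A minimal Rust sketch of a client using reqwest's streaming body (hypothetical client code assuming reqwest with the `stream` feature; not part of this PR, and the URL is illustrative):

```rust
use futures::StreamExt;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let resp = reqwest::get("http://localhost:3000/state").await?;
    let mut body = resp.bytes_stream();
    while let Some(chunk) = body.next().await {
        // Frames look like: data: {"in_flight":0,"in_flight_max":...}\n\n
        // interleaved with periodic keep-alive comment lines every 5 seconds.
        let bytes = chunk?;
        print!("{}", String::from_utf8_lossy(&bytes));
    }
    Ok(())
}
```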
@@ -1328,6 +1352,7 @@ metrics,
     openai_get_model_info,
     sagemaker_compatibility,
     get_chat_tokenize,
+    state,
 ),
 components(
     schemas(
@@ -2007,6 +2032,11 @@ async fn start(
         "Current batch size"
     );
     metrics::describe_gauge!("tgi_queue_size", metrics::Unit::Count, "Current queue size");
+    metrics::describe_gauge!(
+        "tgi_queue_size_tokens",
+        metrics::Unit::Count,
+        "Current queue size in number of tokens"
+    );
     metrics::describe_gauge!(
         "tgi_batch_current_max_tokens",
         metrics::Unit::Count,
@@ -2203,6 +2233,7 @@ async fn start(
         .route("/health", get(health))
         .route("/ping", get(health))
         .route("/metrics", get(metrics))
+        .route("/state", get(state))
         .route("/v1/models", get(openai_get_model_info));

     let compute_type =