Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-21 06:42:10 +00:00
Add tgi_batch_current_size and tgi_queue_size as response headers
parent c20025dbf7
commit b4187d6022
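
The diff below touches the router's server module (home of generate_internal and generate_stream_internal) plus the backends' batching_task and queue_task. The pattern is the same everywhere: two process-wide AtomicU32 counters are declared next to the HTTP handlers, the scheduling tasks store the current batch size and queue length with SeqCst ordering wherever the existing Prometheus gauges are updated, and each request handler loads the counters to fill the new x-batch-current-size and x-queue-size response headers. A minimal sketch of that pattern, with hypothetical helper names (on_batch_update, on_queue_append, on_batch_pulled) standing in for the inline call sites shown in the hunks:

use std::sync::atomic::{AtomicU32, Ordering::SeqCst};

// Shared, lock-free counters; the real statics live in text_generation_router::server.
pub static BATCH_CURRENT_SIZE: AtomicU32 = AtomicU32::new(0);
pub static QUEUE_SIZE: AtomicU32 = AtomicU32::new(0);

// Batching loop: record the size of the batch being decoded, reset to 0 when it drains.
fn on_batch_update(batch_size: u32) {
    BATCH_CURRENT_SIZE.store(batch_size, SeqCst);
}

// Queue task: bump on every appended entry...
fn on_queue_append() {
    QUEUE_SIZE.fetch_add(1, SeqCst);
}

// ...and overwrite with the exact number of entries left once a batch is pulled.
fn on_batch_pulled(remaining_entries: usize) {
    QUEUE_SIZE.store(remaining_entries as u32, SeqCst);
}

// Request handlers: snapshot the counters into response headers.
fn scheduler_headers() -> Vec<(&'static str, String)> {
    vec![
        ("x-batch-current-size", BATCH_CURRENT_SIZE.load(SeqCst).to_string()),
        ("x-queue-size", QUEUE_SIZE.load(SeqCst).to_string()),
    ]
}

fn main() {
    on_queue_append();
    on_batch_pulled(0);
    on_batch_update(4);
    for (name, value) in scheduler_headers() {
        println!("{name}: {value}");
    }
}
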
@@ -3,8 +3,10 @@ use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, Sharded
 use crate::queue::{Entry, Queue};
 use async_trait::async_trait;
 use nohash_hasher::IntMap;
+use std::sync::atomic::Ordering::SeqCst;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::server::BATCH_CURRENT_SIZE;
 use text_generation_router::validation::ValidGenerateRequest;
 use text_generation_router::{FinishReason, PrefillToken, Token};
 use tokio::sync::mpsc::error::SendError;
@@ -155,6 +157,7 @@ pub(crate) async fn batching_task(
                 let batch_max_tokens = batch.max_tokens;
                 let mut batches = vec![batch];
                 metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
+                BATCH_CURRENT_SIZE.store(batch_size, SeqCst);
                 metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
 
                 let min_size = if waiting_tokens >= max_waiting_tokens {
@@ -227,6 +230,7 @@ pub(crate) async fn batching_task(
                 waiting_tokens += 1;
             }
             metrics::gauge!("tgi_batch_current_size").set(0.0);
+            BATCH_CURRENT_SIZE.store(0, SeqCst);
             metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
         }
     }
@@ -4,8 +4,10 @@ use crate::client::{
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
 use std::collections::VecDeque;
+use std::sync::atomic::Ordering::SeqCst;
 use text_generation_router::infer::InferError;
 use text_generation_router::infer::InferStreamResponse;
+use text_generation_router::server::QUEUE_SIZE;
 use text_generation_router::validation::{
     ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
 };
@@ -112,6 +114,7 @@ async fn queue_task(
             QueueCommand::Append(entry, span) => {
                 span.in_scope(|| state.append(*entry));
                 metrics::gauge!("tgi_queue_size").increment(1.0);
+                QUEUE_SIZE.fetch_add(1, SeqCst);
             }
             QueueCommand::NextBatch {
                 min_size,
@@ -125,6 +128,7 @@ async fn queue_task(
                     state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
                 response_sender.send(next_batch).unwrap();
                 metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
+                QUEUE_SIZE.store(state.entries.len() as u32, SeqCst);
             }),
         }
     }
@@ -5,8 +5,10 @@ use crate::client::{
 use crate::queue::{Entry, Queue};
 use async_trait::async_trait;
 use nohash_hasher::IntMap;
+use std::sync::atomic::Ordering::SeqCst;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::server::BATCH_CURRENT_SIZE;
 use text_generation_router::validation::ValidGenerateRequest;
 use text_generation_router::{FinishReason, PrefillToken, Token};
 use tokio::sync::mpsc::error::SendError;
@@ -164,6 +166,7 @@ pub(crate) async fn batching_task(
                 let current_tokens = batch.current_tokens;
                 let mut batches = vec![batch];
                 metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
+                BATCH_CURRENT_SIZE.store(batch_size, SeqCst);
                 metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
 
                 let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
@@ -284,6 +287,7 @@ pub(crate) async fn batching_task(
                 waiting_tokens += 1;
             }
             metrics::gauge!("tgi_batch_current_size").set(0.0);
+            BATCH_CURRENT_SIZE.store(0, SeqCst);
             metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
         }
     }
@@ -6,8 +6,10 @@ use crate::client::{
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::max;
 use std::collections::VecDeque;
+use std::sync::atomic::Ordering::SeqCst;
 use text_generation_router::infer::InferError;
 use text_generation_router::infer::InferStreamResponse;
+use text_generation_router::server::QUEUE_SIZE;
 use text_generation_router::validation::{
     Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
     ValidStoppingParameters,
@@ -140,6 +142,7 @@ async fn queue_task(
             QueueCommand::Append(entry, span) => {
                 span.in_scope(|| state.append(*entry));
                 metrics::gauge!("tgi_queue_size").increment(1.0);
+                QUEUE_SIZE.fetch_add(1, SeqCst);
             }
             QueueCommand::NextBatch {
                 min_size,
@@ -155,6 +158,7 @@ async fn queue_task(
                     .await;
                 response_sender.send(next_batch).unwrap();
                 metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
+                QUEUE_SIZE.store(state.entries.len() as u32, SeqCst);
             }
         }
     }
@@ -54,6 +54,8 @@ use std::fs::File;
 use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::{Path, PathBuf};
+use std::sync::atomic::AtomicU32;
+use std::sync::atomic::Ordering::SeqCst;
 use thiserror::Error;
 use tokio::select;
 use tokio::signal;
@@ -400,6 +402,14 @@ pub(crate) async fn generate_internal(
         "x-generated-tokens",
         response.generated_text.generated_tokens.into(),
     );
+    headers.insert(
+        "x-batch-current-size",
+        BATCH_CURRENT_SIZE.load(SeqCst).to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-queue-size",
+        QUEUE_SIZE.load(SeqCst).to_string().parse().unwrap(),
+    );
 
     // Metrics
     metrics::counter!("tgi_request_success").increment(1);
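
One detail worth noting in the handler hunks: headers.insert expects an http::HeaderValue, which is why the diff formats the counter with .to_string() and then .parse().unwrap()s it back. A hedged aside, assuming the http crate used by the axum-based router: HeaderValue also implements From<u32>, so the same value could be built without the string round-trip or the unwrap, e.g.:

use http::HeaderValue;
use std::sync::atomic::{AtomicU32, Ordering::SeqCst};

static QUEUE_SIZE: AtomicU32 = AtomicU32::new(0);

// Infallible equivalent of `QUEUE_SIZE.load(SeqCst).to_string().parse().unwrap()`.
fn queue_size_header() -> HeaderValue {
    HeaderValue::from(QUEUE_SIZE.load(SeqCst))
}

fn main() {
    println!("{:?}", queue_size_header());
}
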
@@ -490,6 +500,9 @@ async fn generate_stream(
     (headers, sse)
 }
 
+pub static BATCH_CURRENT_SIZE: AtomicU32 = AtomicU32::new(0);
+pub static QUEUE_SIZE: AtomicU32 = AtomicU32::new(0);
+
 async fn generate_stream_internal(
     infer: Infer,
     ComputeType(compute_type): ComputeType,
@@ -513,6 +526,14 @@ async fn generate_stream_internal(
         compute_characters.to_string().parse().unwrap(),
     );
     headers.insert("X-Accel-Buffering", "no".parse().unwrap());
+    headers.insert(
+        "x-batch-current-size",
+        BATCH_CURRENT_SIZE.load(SeqCst).to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-queue-size",
+        QUEUE_SIZE.load(SeqCst).to_string().parse().unwrap(),
+    );
 
     let stream = async_stream::stream! {
         // Inference
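
With both handlers patched, every response from the generation endpoints carries the two new headers. A rough client-side sketch of how they could be read; the reqwest and serde_json crates, the /generate route, the request body shape, and the localhost:3000 address are assumptions for illustration, not part of this commit:

use std::error::Error;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let resp = reqwest::Client::new()
        .post("http://localhost:3000/generate")
        .json(&serde_json::json!({
            "inputs": "Hello",
            "parameters": { "max_new_tokens": 8 }
        }))
        .send()
        .await?;

    // Scheduler state captured when the router built this response.
    for name in ["x-batch-current-size", "x-queue-size"] {
        if let Some(value) = resp.headers().get(name) {
            println!("{name}: {}", value.to_str()?);
        }
    }
    Ok(())
}
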