Mirror of https://github.com/huggingface/text-generation-inference.git

commit e7503a4240 (parent 3b2d1a2854)

    add store true when successful prefill/decode

Changed files: router/src/health.rs (new file, 62 lines), router/src/infer.rs,
router/src/lib.rs, router/src/server.rs
router/src/health.rs (new file):

@@ -0,0 +1,62 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+use text_generation_client::{
+    Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
+};
+
+#[derive(Clone, Debug)]
+pub(crate) struct Health {
+    client: ShardedClient,
+    generation_health: Arc<AtomicBool>,
+}
+
+impl Health {
+    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
+        Self {
+            client,
+            generation_health,
+        }
+    }
+
+    pub(crate) async fn check(&mut self) -> bool {
+        if self.generation_health.load(Ordering::SeqCst) {
+            // Generation is healthy, we only check that the shards are answering gRPC calls
+            self.client.health().await.is_ok()
+        } else {
+            // Generation is unhealthy, or no generation request has been sent yet
+
+            // Dummy batch of 1 token and 1 generated token
+            let liveness_request = Request {
+                id: u64::MAX,
+                inputs: "liveness".to_string(),
+                truncate: 10,
+                parameters: Some(NextTokenChooserParameters {
+                    temperature: 1.0,
+                    top_k: 0,
+                    top_p: 1.0,
+                    typical_p: 1.0,
+                    do_sample: false,
+                    seed: 0,
+                    repetition_penalty: 1.0,
+                    watermark: false,
+                }),
+                stopping_parameters: Some(StoppingCriteriaParameters {
+                    max_new_tokens: 1,
+                    stop_sequences: vec![],
+                    ignore_eos_token: false,
+                }),
+            };
+            let batch = Batch {
+                id: u64::MAX,
+                requests: vec![liveness_request],
+                size: 1,
+                max_tokens: 2,
+            };
+            // Skips the queue
+            let value = self.client.prefill(batch).await.is_ok();
+            // Update generation health
+            self.generation_health.store(value, Ordering::SeqCst);
+            value
+        }
+    }
+}
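The new Health type implements a two-tier probe: once any real prefill or decode has succeeded, the AtomicBool short-circuits check() to a cheap shard gRPC ping; only when the flag is false (at startup, or after a failure) does it pay for a full one-token dummy generation, recording the outcome for the next call. A minimal self-contained sketch of that flag logic, with a hypothetical Probe type standing in for ShardedClient:

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Hypothetical stand-in for ShardedClient: cheap ping vs. expensive probe.
struct Probe;

impl Probe {
    fn ping(&self) -> bool {
        // e.g. a gRPC health call: fast, does not run the model
        true
    }
    fn dummy_generation(&self) -> bool {
        // e.g. a 1-token prefill: slow, proves the model can actually generate
        true
    }
}

struct Health {
    probe: Probe,
    generation_health: Arc<AtomicBool>,
}

impl Health {
    fn check(&self) -> bool {
        if self.generation_health.load(Ordering::SeqCst) {
            // Last generation succeeded: a cheap ping is enough.
            self.probe.ping()
        } else {
            // No successful generation yet (or the last one failed):
            // run the expensive probe and record the outcome.
            let ok = self.probe.dummy_generation();
            self.generation_health.store(ok, Ordering::SeqCst);
            ok
        }
    }
}

fn main() {
    let health = Health {
        probe: Probe,
        generation_health: Arc::new(AtomicBool::new(false)),
    };
    assert!(health.check()); // first call takes the expensive path
    assert!(health.check()); // later calls take the cheap path
}

Note also that the dummy request uses id: u64::MAX so it can never collide with a real batch id, and max_tokens: 2 budgets exactly one input token plus one generated token.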
router/src/infer.rs:

@@ -30,8 +30,6 @@ pub struct Infer {
     shared: Arc<Shared>,
     /// Inference limit
     limit_concurrent_requests: Arc<Semaphore>,
-    /// Has done roundtrip valid run
-    healthy: Arc<AtomicBool>,
 }
 
 /// Infer shared state
@@ -41,6 +39,7 @@ struct Shared {
 }
 
 impl Infer {
+    #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         client: ShardedClient,
         validation: Validation,
@@ -49,6 +48,7 @@ impl Infer {
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
         requires_padding: bool,
+        generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(requires_padding);
@@ -64,29 +64,20 @@ impl Infer {
             max_waiting_tokens,
             queue.clone(),
             shared.clone(),
+            generation_health,
         ));
 
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
-        let healthy = Arc::new(AtomicBool::new(false));
 
         Self {
             validation,
             queue,
             shared,
             limit_concurrent_requests: semaphore,
-            healthy,
         }
     }
 
-    pub(crate) fn healthy(&self) -> bool {
-        self.healthy.load(Ordering::SeqCst)
-    }
-
-    pub(crate) fn set_healthy(&self, value: bool) {
-        self.healthy.store(value, Ordering::SeqCst)
-    }
-
     /// Add a new request to the queue and return a stream of InferStreamResponse
     #[instrument(skip(self))]
     pub(crate) async fn generate_stream(
@@ -255,6 +246,7 @@ async fn batching_task(
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
+    generation_health: Arc<AtomicBool>,
 ) {
     // Infinite loop
     loop {
@@ -267,7 +259,7 @@ async fn batching_task(
         while let Some((mut entries, batch, span)) =
             queue.next_batch(None, max_batch_total_tokens).await
         {
-            let mut cached_batch = prefill(&mut client, batch, &mut entries)
+            let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
                 .instrument(span)
                 .await;
             let mut waiting_tokens = 1;
@@ -316,9 +308,10 @@ async fn batching_task(
                     });
 
                     // Generate one token for this new batch to have the attention past in cache
-                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
-                        .instrument(span)
-                        .await;
+                    let new_cached_batch =
+                        prefill(&mut client, new_batch, &mut new_entries, &generation_health)
+                            .instrument(span)
+                            .await;
                     // Reset waiting counter
                     waiting_tokens = 1;
                     // Extend current batch with the new batch
@@ -342,7 +335,7 @@ async fn batching_task(
                     entry.temp_span = Some(entry_batch_span);
                 });
 
-                cached_batch = decode(&mut client, batches, &mut entries)
+                cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
                     .instrument(next_batch_span)
                     .await;
                 waiting_tokens += 1;
@@ -358,6 +351,7 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
@@ -365,6 +359,8 @@ async fn prefill(
 
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
+            // Update health
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
@@ -377,6 +373,8 @@ async fn prefill(
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
+            // Update health
+            generation_health.store(false, Ordering::SeqCst);
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
@@ -390,6 +388,7 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<Batch>,
     entries: &mut IntMap<u64, Entry>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
@@ -397,6 +396,8 @@ async fn decode(
 
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
+            // Update health
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
@@ -409,6 +410,7 @@ async fn decode(
         }
        // If we have an error, we discard the whole batch
         Err(err) => {
+            generation_health.store(false, Ordering::SeqCst);
             for id in batch_ids {
                 let _ = client.clear_cache(Some(id)).await;
             }
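In infer.rs the flag's write side piggybacks on real traffic: every prefill and decode stores true on Ok and false on Err, so /health stays cheap while requests are flowing and only falls back to the dummy probe after a failure or before the first request. A sketch of that store-on-outcome idea, factored into a hypothetical record_health helper (not part of this commit):

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Hypothetical helper: run a fallible operation and mirror its
/// outcome into a shared health flag, as prefill/decode do above.
fn record_health<T, E>(flag: &Arc<AtomicBool>, result: Result<T, E>) -> Result<T, E> {
    flag.store(result.is_ok(), Ordering::SeqCst);
    result
}

fn main() {
    let flag = Arc::new(AtomicBool::new(false));

    // A successful round-trip marks the system healthy...
    let _ = record_health::<_, ()>(&flag, Ok(42));
    assert!(flag.load(Ordering::SeqCst));

    // ...and a failed one immediately marks it unhealthy again.
    let _ = record_health::<(), _>(&flag, Err("shard error"));
    assert!(!flag.load(Ordering::SeqCst));
}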
router/src/lib.rs:

@@ -1,3 +1,4 @@
+mod health;
 /// Text Generation Inference Webserver
 mod infer;
 mod queue;
@@ -7,7 +8,6 @@ mod validation;
 use infer::Infer;
 use queue::{Entry, Queue};
 use serde::{Deserialize, Serialize};
-use text_generation_client::ShardedClient;
 use utoipa::ToSchema;
 use validation::Validation;
 
@@ -20,11 +20,6 @@ pub struct HubModelInfo {
     pub pipeline_tag: Option<String>,
 }
 
-#[derive(Clone, Debug)]
-pub struct Health {
-    pub client: ShardedClient,
-}
-
 #[derive(Clone, Debug, Serialize, ToSchema)]
 pub struct Info {
     /// Model info
router/src/server.rs:

@@ -1,10 +1,11 @@
+use crate::health::Health;
 /// HTTP Server logic
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
-    GenerateParameters, GenerateRequest, GenerateResponse, Health, HubModelInfo, Infer, Info,
-    PrefillToken, StreamDetails, StreamResponse, Token, Validation,
+    GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
+    StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -18,6 +19,8 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
@@ -86,54 +89,25 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
     get,
     tag = "Text Generation Inference",
     path = "/health",
-    request_body = HealthRequest,
     responses(
         (status = 200, description = "Everything is working fine"),
-        (status = 500, description = "Text generation inference is down", body = ErrorResponse,
+        (status = 503, description = "Text generation inference is down", body = ErrorResponse,
         example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
     )
 )]
-#[instrument(skip(infer))]
+#[instrument(skip(health))]
 /// Health check method
-async fn health(
-    mut health: Extension<Health>,
-    infer: Extension<Infer>,
-) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
-    if infer.healthy() {
-        health.client.health().await.map_err(|_| {
-            (
-                StatusCode::INTERNAL_SERVER_ERROR,
-                Json(ErrorResponse {
-                    error: "unhealthy".to_string(),
-                    error_type: "healthcheck".to_string(),
-                }),
-            )
-        })?;
-    } else {
-        infer
-            .generate(GenerateRequest {
-                inputs: "liveness".to_string(),
-                parameters: GenerateParameters {
-                    best_of: None,
-                    temperature: None,
-                    repetition_penalty: None,
-                    top_k: None,
-                    top_p: None,
-                    typical_p: None,
-                    do_sample: false,
-                    max_new_tokens: 1,
-                    return_full_text: None,
-                    stop: Vec::new(),
-                    truncate: None,
-                    watermark: false,
-                    details: false,
-                    seed: None,
-                },
-            })
-            .await?;
-        infer.set_healthy(true);
-    }
-    Ok(axum::Json(()))
+async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
+    match health.check().await {
+        true => Ok(()),
+        false => Err((
+            StatusCode::SERVICE_UNAVAILABLE,
+            Json(ErrorResponse {
+                error: "unhealthy".to_string(),
+                error_type: "healthcheck".to_string(),
+            }),
+        )),
+    }
 }
 
 /// Generate tokens
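For callers, the handler change means the status code alone carries the verdict: 200 when check() passes, 503 Service Unavailable (rather than the previous 500) with the {"error": "unhealthy"} body when it fails, which is what load balancers and orchestrators generally expect from a readiness probe. A hedged sketch of a poller, assuming the reqwest and tokio crates and an illustrative localhost URL:

// Minimal sketch of a /health poller; assumes the `reqwest` and
// `tokio` crates and a router listening on localhost:3000.
#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let resp = reqwest::get("http://localhost:3000/health").await?;
    match resp.status().as_u16() {
        200 => println!("healthy: routing traffic"),
        503 => println!("unhealthy: draining this replica"),
        other => println!("unexpected status {other}"),
    }
    Ok(())
}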
@@ -184,27 +158,10 @@ async fn generate(
     // Inference
     let (response, best_of_responses) = match req.0.parameters.best_of {
         Some(best_of) if best_of > 1 => {
-            let (response, best_of_responses) = match infer.generate_best_of(req.0, best_of).await {
-                Ok(result) => result,
-                Err(err) => {
-                    infer.set_healthy(false);
-                    return Err(err)?;
-                }
-            };
+            let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
             (response, Some(best_of_responses))
         }
-        _ => (
-            {
-                match infer.generate(req.0).await {
-                    Ok(result) => result,
-                    Err(err) => {
-                        infer.set_healthy(false);
-                        return Err(err)?;
-                    }
-                }
-            },
-            None,
-        ),
+        _ => (infer.generate(req.0).await?, None),
     };
 
     // Token details
@@ -483,7 +440,6 @@ async fn generate_stream(
                 // yield error
                 Err(err) => {
                     error = true;
-                    infer.set_healthy(false);
                     yield Ok(Event::from(err));
                     break;
                 }
@@ -595,9 +551,8 @@ pub async fn run(
         max_input_length,
         max_total_tokens,
     );
-    let health_ext = Health {
-        client: client.clone(),
-    };
+    let healthy = Arc::new(AtomicBool::new(false));
+    let health_ext = Health::new(client.clone(), healthy.clone());
     let infer = Infer::new(
         client,
         validation,
@@ -606,6 +561,7 @@ pub async fn run(
         max_waiting_tokens,
         max_concurrent_requests,
         shard_info.requires_padding,
+        healthy,
     );
 
     // Duration buckets
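The wiring in run() is the crux of the commit: a single Arc<AtomicBool>, created once, is cloned into both Health (which reads it in the /health handler and resets it via its own probe) and Infer's batching task (which writes it after every real prefill/decode). A stripped-down sketch of that shared-flag ownership; the thread stands in for the spawned batching task, and the names are illustrative, not the actual TGI types:

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;

fn main() {
    // One flag, two owners: the health checker reads it,
    // the serving loop writes it after every real batch.
    let healthy = Arc::new(AtomicBool::new(false));

    let writer = Arc::clone(&healthy);
    let serving_loop = thread::spawn(move || {
        // Stand-in for the batching task: a successful prefill/decode
        // stores true; an error would store false.
        writer.store(true, Ordering::SeqCst);
    });
    serving_loop.join().unwrap();

    // Stand-in for Health::check: the flag decides cheap vs. expensive path.
    assert!(healthy.load(Ordering::SeqCst));
}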