mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
add store true when successful prefill/decode
This commit is contained in:
parent
3b2d1a2854
commit
e7503a4240
62
router/src/health.rs
Normal file
62
router/src/health.rs
Normal file
@ -0,0 +1,62 @@
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{
|
||||
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct Health {
|
||||
client: ShardedClient,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl Health {
|
||||
pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
|
||||
Self {
|
||||
client,
|
||||
generation_health,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn check(&mut self) -> bool {
|
||||
if self.generation_health.load(Ordering::SeqCst) {
|
||||
// Generation is healthy, we only check that the shards are answering gRPC calls
|
||||
self.client.health().await.is_ok()
|
||||
} else {
|
||||
// Generation is unhealthy or have not sent any generation request yet
|
||||
|
||||
// Dummy batch of 1 token and 1 generated token
|
||||
let liveness_request = Request {
|
||||
id: u64::MAX,
|
||||
inputs: "liveness".to_string(),
|
||||
truncate: 10,
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 1.0,
|
||||
top_k: 0,
|
||||
top_p: 1.0,
|
||||
typical_p: 1.0,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 1.0,
|
||||
watermark: false,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: 1,
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: false,
|
||||
}),
|
||||
};
|
||||
let batch = Batch {
|
||||
id: u64::MAX,
|
||||
requests: vec![liveness_request],
|
||||
size: 1,
|
||||
max_tokens: 2,
|
||||
};
|
||||
// Skips the queue
|
||||
let value = self.client.prefill(batch).await.is_ok();
|
||||
// Update generation health
|
||||
self.generation_health.store(value, Ordering::SeqCst);
|
||||
value
|
||||
}
|
||||
}
|
||||
}
|
@ -30,8 +30,6 @@ pub struct Infer {
|
||||
shared: Arc<Shared>,
|
||||
/// Inference limit
|
||||
limit_concurrent_requests: Arc<Semaphore>,
|
||||
/// Has done roundtrip valid run
|
||||
healthy: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
/// Infer shared state
|
||||
@ -41,6 +39,7 @@ struct Shared {
|
||||
}
|
||||
|
||||
impl Infer {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
validation: Validation,
|
||||
@ -49,6 +48,7 @@ impl Infer {
|
||||
max_waiting_tokens: usize,
|
||||
max_concurrent_requests: usize,
|
||||
requires_padding: bool,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
) -> Self {
|
||||
// Infer shared state
|
||||
let queue = Queue::new(requires_padding);
|
||||
@ -64,29 +64,20 @@ impl Infer {
|
||||
max_waiting_tokens,
|
||||
queue.clone(),
|
||||
shared.clone(),
|
||||
generation_health,
|
||||
));
|
||||
|
||||
// Inference limit with a semaphore
|
||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||
let healthy = Arc::new(AtomicBool::new(false));
|
||||
|
||||
Self {
|
||||
validation,
|
||||
queue,
|
||||
shared,
|
||||
limit_concurrent_requests: semaphore,
|
||||
healthy,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn healthy(&self) -> bool {
|
||||
self.healthy.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
pub(crate) fn set_healthy(&self, value: bool) {
|
||||
self.healthy.store(value, Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Add a new request to the queue and return a stream of InferStreamResponse
|
||||
#[instrument(skip(self))]
|
||||
pub(crate) async fn generate_stream(
|
||||
@ -255,6 +246,7 @@ async fn batching_task(
|
||||
max_waiting_tokens: usize,
|
||||
queue: Queue,
|
||||
shared: Arc<Shared>,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
) {
|
||||
// Infinite loop
|
||||
loop {
|
||||
@ -267,7 +259,7 @@ async fn batching_task(
|
||||
while let Some((mut entries, batch, span)) =
|
||||
queue.next_batch(None, max_batch_total_tokens).await
|
||||
{
|
||||
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
||||
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
|
||||
.instrument(span)
|
||||
.await;
|
||||
let mut waiting_tokens = 1;
|
||||
@ -316,9 +308,10 @@ async fn batching_task(
|
||||
});
|
||||
|
||||
// Generate one token for this new batch to have the attention past in cache
|
||||
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
let new_cached_batch =
|
||||
prefill(&mut client, new_batch, &mut new_entries, &generation_health)
|
||||
.instrument(span)
|
||||
.await;
|
||||
// Reset waiting counter
|
||||
waiting_tokens = 1;
|
||||
// Extend current batch with the new batch
|
||||
@ -342,7 +335,7 @@ async fn batching_task(
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
});
|
||||
|
||||
cached_batch = decode(&mut client, batches, &mut entries)
|
||||
cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
|
||||
.instrument(next_batch_span)
|
||||
.await;
|
||||
waiting_tokens += 1;
|
||||
@ -358,6 +351,7 @@ async fn prefill(
|
||||
client: &mut ShardedClient,
|
||||
batch: Batch,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
generation_health: &Arc<AtomicBool>,
|
||||
) -> Option<Batch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_id = batch.id;
|
||||
@ -365,6 +359,8 @@ async fn prefill(
|
||||
|
||||
match client.prefill(batch).await {
|
||||
Ok((generations, next_batch)) => {
|
||||
// Update health
|
||||
generation_health.store(true, Ordering::SeqCst);
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
@ -377,6 +373,8 @@ async fn prefill(
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
// Update health
|
||||
generation_health.store(false, Ordering::SeqCst);
|
||||
let _ = client.clear_cache(Some(batch_id)).await;
|
||||
send_errors(err, entries);
|
||||
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
|
||||
@ -390,6 +388,7 @@ async fn decode(
|
||||
client: &mut ShardedClient,
|
||||
batches: Vec<Batch>,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
generation_health: &Arc<AtomicBool>,
|
||||
) -> Option<Batch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
||||
@ -397,6 +396,8 @@ async fn decode(
|
||||
|
||||
match client.decode(batches).await {
|
||||
Ok((generations, next_batch)) => {
|
||||
// Update health
|
||||
generation_health.store(true, Ordering::SeqCst);
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
@ -409,6 +410,7 @@ async fn decode(
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
generation_health.store(false, Ordering::SeqCst);
|
||||
for id in batch_ids {
|
||||
let _ = client.clear_cache(Some(id)).await;
|
||||
}
|
||||
|
@ -1,3 +1,4 @@
|
||||
mod health;
|
||||
/// Text Generation Inference Webserver
|
||||
mod infer;
|
||||
mod queue;
|
||||
@ -7,7 +8,6 @@ mod validation;
|
||||
use infer::Infer;
|
||||
use queue::{Entry, Queue};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use text_generation_client::ShardedClient;
|
||||
use utoipa::ToSchema;
|
||||
use validation::Validation;
|
||||
|
||||
@ -20,11 +20,6 @@ pub struct HubModelInfo {
|
||||
pub pipeline_tag: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Health {
|
||||
pub client: ShardedClient,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
pub struct Info {
|
||||
/// Model info
|
||||
|
@ -1,10 +1,11 @@
|
||||
use crate::health::Health;
|
||||
/// HTTP Server logic
|
||||
use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
||||
use crate::validation::ValidationError;
|
||||
use crate::{
|
||||
BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
|
||||
GenerateParameters, GenerateRequest, GenerateResponse, Health, HubModelInfo, Infer, Info,
|
||||
PrefillToken, StreamDetails, StreamResponse, Token, Validation,
|
||||
GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
|
||||
StreamDetails, StreamResponse, Token, Validation,
|
||||
};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
@ -18,6 +19,8 @@ use futures::Stream;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{ShardInfo, ShardedClient};
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::signal;
|
||||
@ -86,54 +89,25 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
|
||||
get,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/health",
|
||||
request_body = HealthRequest,
|
||||
responses(
|
||||
(status = 200, description = "Everything is working fine"),
|
||||
(status = 500, description = "Text generation inference is down", body = ErrorResponse,
|
||||
(status = 503, description = "Text generation inference is down", body = ErrorResponse,
|
||||
example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
|
||||
)
|
||||
)]
|
||||
#[instrument(skip(infer))]
|
||||
#[instrument(skip(health))]
|
||||
/// Health check method
|
||||
async fn health(
|
||||
mut health: Extension<Health>,
|
||||
infer: Extension<Infer>,
|
||||
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
|
||||
if infer.healthy() {
|
||||
health.client.health().await.map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "unhealthy".to_string(),
|
||||
error_type: "healthcheck".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
} else {
|
||||
infer
|
||||
.generate(GenerateRequest {
|
||||
inputs: "liveness".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature: None,
|
||||
repetition_penalty: None,
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
typical_p: None,
|
||||
do_sample: false,
|
||||
max_new_tokens: 1,
|
||||
return_full_text: None,
|
||||
stop: Vec::new(),
|
||||
truncate: None,
|
||||
watermark: false,
|
||||
details: false,
|
||||
seed: None,
|
||||
},
|
||||
})
|
||||
.await?;
|
||||
infer.set_healthy(true);
|
||||
async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||
match health.check().await {
|
||||
true => Ok(()),
|
||||
false => Err((
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(ErrorResponse {
|
||||
error: "unhealthy".to_string(),
|
||||
error_type: "healthcheck".to_string(),
|
||||
}),
|
||||
)),
|
||||
}
|
||||
Ok(axum::Json(()))
|
||||
}
|
||||
|
||||
/// Generate tokens
|
||||
@ -184,27 +158,10 @@ async fn generate(
|
||||
// Inference
|
||||
let (response, best_of_responses) = match req.0.parameters.best_of {
|
||||
Some(best_of) if best_of > 1 => {
|
||||
let (response, best_of_responses) = match infer.generate_best_of(req.0, best_of).await {
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
infer.set_healthy(false);
|
||||
return Err(err)?;
|
||||
}
|
||||
};
|
||||
let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
|
||||
(response, Some(best_of_responses))
|
||||
}
|
||||
_ => (
|
||||
{
|
||||
match infer.generate(req.0).await {
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
infer.set_healthy(false);
|
||||
return Err(err)?;
|
||||
}
|
||||
}
|
||||
},
|
||||
None,
|
||||
),
|
||||
_ => (infer.generate(req.0).await?, None),
|
||||
};
|
||||
|
||||
// Token details
|
||||
@ -483,7 +440,6 @@ async fn generate_stream(
|
||||
// yield error
|
||||
Err(err) => {
|
||||
error = true;
|
||||
infer.set_healthy(false);
|
||||
yield Ok(Event::from(err));
|
||||
break;
|
||||
}
|
||||
@ -595,9 +551,8 @@ pub async fn run(
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
);
|
||||
let health_ext = Health {
|
||||
client: client.clone(),
|
||||
};
|
||||
let healthy = Arc::new(AtomicBool::new(false));
|
||||
let health_ext = Health::new(client.clone(), healthy.clone());
|
||||
let infer = Infer::new(
|
||||
client,
|
||||
validation,
|
||||
@ -606,6 +561,7 @@ pub async fn run(
|
||||
max_waiting_tokens,
|
||||
max_concurrent_requests,
|
||||
shard_info.requires_padding,
|
||||
healthy,
|
||||
);
|
||||
|
||||
// Duration buckets
|
||||
|
Loading…
Reference in New Issue
Block a user