add store true when successful prefill/decode

OlivierDehaene 2023-04-26 19:14:21 +02:00
parent 3b2d1a2854
commit e7503a4240
4 changed files with 104 additions and 89 deletions

router/src/health.rs (new file)
@@ -0,0 +1,62 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use text_generation_client::{
    Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
};

#[derive(Clone, Debug)]
pub(crate) struct Health {
    client: ShardedClient,
    generation_health: Arc<AtomicBool>,
}

impl Health {
    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
        Self {
            client,
            generation_health,
        }
    }

    pub(crate) async fn check(&mut self) -> bool {
        if self.generation_health.load(Ordering::SeqCst) {
            // Generation is healthy, we only check that the shards are answering gRPC calls
            self.client.health().await.is_ok()
        } else {
            // Generation is unhealthy, or we have not sent any generation request yet
            // Dummy batch of 1 token and 1 generated token
            let liveness_request = Request {
                id: u64::MAX,
                inputs: "liveness".to_string(),
                truncate: 10,
                parameters: Some(NextTokenChooserParameters {
                    temperature: 1.0,
                    top_k: 0,
                    top_p: 1.0,
                    typical_p: 1.0,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 1.0,
                    watermark: false,
                }),
                stopping_parameters: Some(StoppingCriteriaParameters {
                    max_new_tokens: 1,
                    stop_sequences: vec![],
                    ignore_eos_token: false,
                }),
            };
            let batch = Batch {
                id: u64::MAX,
                requests: vec![liveness_request],
                size: 1,
                max_tokens: 2,
            };
            // Skips the queue
            let value = self.client.prefill(batch).await.is_ok();
            // Update generation health
            self.generation_health.store(value, Ordering::SeqCst);
            value
        }
    }
}
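
The flag-plus-fallback logic in check() can be exercised on its own. Below is a minimal, self-contained sketch (synchronous, standard library only, not code from this commit); grpc_ok and prefill_ok are hypothetical closures standing in for the real client.health() and client.prefill(batch) gRPC calls, which need running shards.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

struct HealthSketch {
    generation_health: Arc<AtomicBool>,
}

impl HealthSketch {
    fn check(&self, grpc_ok: impl Fn() -> bool, prefill_ok: impl Fn() -> bool) -> bool {
        if self.generation_health.load(Ordering::SeqCst) {
            // Generation recently succeeded: a cheap shard ping is enough.
            grpc_ok()
        } else {
            // Never generated, or the last batch failed: run a dummy prefill
            // and remember the outcome for the next check.
            let ok = prefill_ok();
            self.generation_health.store(ok, Ordering::SeqCst);
            ok
        }
    }
}

fn main() {
    let flag = Arc::new(AtomicBool::new(false));
    let health = HealthSketch { generation_health: flag.clone() };

    assert!(!health.check(|| true, || false)); // dummy prefill fails -> still unhealthy
    assert!(health.check(|| true, || true));   // dummy prefill succeeds -> flag flips to true
    assert!(health.check(|| true, || unreachable!())); // flag is set -> only the cheap ping runs
}

Keeping the cheap gRPC ping on the fast path means routine health probes stop scheduling dummy batches once a real generation has succeeded.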

router/src/infer.rs
@@ -30,8 +30,6 @@ pub struct Infer
     shared: Arc<Shared>,
     /// Inference limit
     limit_concurrent_requests: Arc<Semaphore>,
-    /// Has done roundtrip valid run
-    healthy: Arc<AtomicBool>,
 }
 
 /// Infer shared state
@@ -41,6 +39,7 @@
 }
 
 impl Infer {
+    #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         client: ShardedClient,
         validation: Validation,
@@ -49,6 +48,7 @@ impl Infer {
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
         requires_padding: bool,
+        generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(requires_padding);
@@ -64,29 +64,20 @@ impl Infer {
             max_waiting_tokens,
             queue.clone(),
             shared.clone(),
+            generation_health,
         ));
 
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
-        let healthy = Arc::new(AtomicBool::new(false));
 
         Self {
             validation,
             queue,
             shared,
             limit_concurrent_requests: semaphore,
-            healthy,
         }
     }
 
-    pub(crate) fn healthy(&self) -> bool {
-        self.healthy.load(Ordering::SeqCst)
-    }
-
-    pub(crate) fn set_healthy(&self, value: bool) {
-        self.healthy.store(value, Ordering::SeqCst)
-    }
-
     /// Add a new request to the queue and return a stream of InferStreamResponse
     #[instrument(skip(self))]
     pub(crate) async fn generate_stream(
@@ -255,6 +246,7 @@ async fn batching_task(
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
+    generation_health: Arc<AtomicBool>,
 ) {
     // Infinite loop
     loop {
@@ -267,7 +259,7 @@ async fn batching_task(
         while let Some((mut entries, batch, span)) =
             queue.next_batch(None, max_batch_total_tokens).await
         {
-            let mut cached_batch = prefill(&mut client, batch, &mut entries)
+            let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
                 .instrument(span)
                 .await;
             let mut waiting_tokens = 1;
@@ -316,7 +308,8 @@ async fn batching_task(
                     });
 
                     // Generate one token for this new batch to have the attention past in cache
-                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
-                        .instrument(span)
-                        .await;
+                    let new_cached_batch =
+                        prefill(&mut client, new_batch, &mut new_entries, &generation_health)
+                            .instrument(span)
+                            .await;
                     // Reset waiting counter
@@ -342,7 +335,7 @@ async fn batching_task(
                 entry.temp_span = Some(entry_batch_span);
             });
 
-            cached_batch = decode(&mut client, batches, &mut entries)
+            cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
                 .instrument(next_batch_span)
                 .await;
             waiting_tokens += 1;
@@ -358,6 +351,7 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
@@ -365,6 +359,8 @@ async fn prefill(
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
+            // Update health
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
@@ -377,6 +373,8 @@ async fn prefill(
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
+            // Update health
+            generation_health.store(false, Ordering::SeqCst);
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
@@ -390,6 +388,7 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<Batch>,
     entries: &mut IntMap<u64, Entry>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
@@ -397,6 +396,8 @@ async fn decode(
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
+            // Update health
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
@@ -409,6 +410,7 @@ async fn decode(
         }
        // If we have an error, we discard the whole batch
         Err(err) => {
+            generation_health.store(false, Ordering::SeqCst);
             for id in batch_ids {
                 let _ = client.clear_cache(Some(id)).await;
             }
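
To see the concurrency side of this change in isolation, here is a standalone sketch assuming a tokio runtime like the router's (none of these names come from the repository): a spawned loop plays the role of batching_task and records each batch outcome in the shared AtomicBool, while a reader plays the role of the health check and only loads the latest value. fake_prefill is a made-up placeholder for client.prefill(batch).

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;

async fn fake_prefill() -> bool {
    // Made-up stand-in for `client.prefill(batch).await.is_ok()`
    true
}

#[tokio::main]
async fn main() {
    let generation_health = Arc::new(AtomicBool::new(false));

    // Stand-in for `batching_task`: record the outcome of every batch,
    // true on success and false on error, mirroring `prefill`/`decode` above.
    let flag = generation_health.clone();
    tokio::spawn(async move {
        loop {
            let batch_ok = fake_prefill().await;
            flag.store(batch_ok, Ordering::SeqCst);
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
    });

    // Stand-in for the `/health` handler: a lock-free read of the latest outcome.
    tokio::time::sleep(Duration::from_millis(200)).await;
    println!("healthy = {}", generation_health.load(Ordering::SeqCst));
}

Because the flag is an AtomicBool behind an Arc, the health endpoint never has to await the batching loop or take a lock to learn whether the last prefill/decode succeeded.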

router/src/lib.rs
@@ -1,3 +1,4 @@
+mod health;
 /// Text Generation Inference Webserver
 mod infer;
 mod queue;
@@ -7,7 +8,6 @@ mod validation;
 use infer::Infer;
 use queue::{Entry, Queue};
 use serde::{Deserialize, Serialize};
-use text_generation_client::ShardedClient;
 use utoipa::ToSchema;
 use validation::Validation;
@@ -20,11 +20,6 @@ pub struct HubModelInfo {
     pub pipeline_tag: Option<String>,
 }
 
-#[derive(Clone, Debug)]
-pub struct Health {
-    pub client: ShardedClient,
-}
-
 #[derive(Clone, Debug, Serialize, ToSchema)]
 pub struct Info {
     /// Model info

router/src/server.rs
@@ -1,10 +1,11 @@
+use crate::health::Health;
 /// HTTP Server logic
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
-    GenerateParameters, GenerateRequest, GenerateResponse, Health, HubModelInfo, Infer, Info,
-    PrefillToken, StreamDetails, StreamResponse, Token, Validation,
+    GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
+    StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -18,6 +19,8 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
@@ -86,54 +89,25 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
     get,
     tag = "Text Generation Inference",
     path = "/health",
-    request_body = HealthRequest,
     responses(
         (status = 200, description = "Everything is working fine"),
-        (status = 500, description = "Text generation inference is down", body = ErrorResponse,
+        (status = 503, description = "Text generation inference is down", body = ErrorResponse,
         example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
     )
 )]
-#[instrument(skip(infer))]
+#[instrument(skip(health))]
 /// Health check method
-async fn health(
-    mut health: Extension<Health>,
-    infer: Extension<Infer>,
-) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
-    if infer.healthy() {
-        health.client.health().await.map_err(|_| {
-            (
-                StatusCode::INTERNAL_SERVER_ERROR,
-                Json(ErrorResponse {
-                    error: "unhealthy".to_string(),
-                    error_type: "healthcheck".to_string(),
-                }),
-            )
-        })?;
-    } else {
-        infer
-            .generate(GenerateRequest {
-                inputs: "liveness".to_string(),
-                parameters: GenerateParameters {
-                    best_of: None,
-                    temperature: None,
-                    repetition_penalty: None,
-                    top_k: None,
-                    top_p: None,
-                    typical_p: None,
-                    do_sample: false,
-                    max_new_tokens: 1,
-                    return_full_text: None,
-                    stop: Vec::new(),
-                    truncate: None,
-                    watermark: false,
-                    details: false,
-                    seed: None,
-                },
-            })
-            .await?;
-        infer.set_healthy(true);
-    }
-    Ok(axum::Json(()))
-}
+async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
+    match health.check().await {
+        true => Ok(()),
+        false => Err((
+            StatusCode::SERVICE_UNAVAILABLE,
+            Json(ErrorResponse {
+                error: "unhealthy".to_string(),
+                error_type: "healthcheck".to_string(),
+            }),
+        )),
+    }
+}
 
 /// Generate tokens
@@ -184,27 +158,10 @@ async fn generate(
     // Inference
     let (response, best_of_responses) = match req.0.parameters.best_of {
         Some(best_of) if best_of > 1 => {
-            let (response, best_of_responses) = match infer.generate_best_of(req.0, best_of).await {
-                Ok(result) => result,
-                Err(err) => {
-                    infer.set_healthy(false);
-                    return Err(err)?;
-                }
-            };
+            let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
             (response, Some(best_of_responses))
         }
-        _ => (
-            {
-                match infer.generate(req.0).await {
-                    Ok(result) => result,
-                    Err(err) => {
-                        infer.set_healthy(false);
-                        return Err(err)?;
-                    }
-                }
-            },
-            None,
-        ),
+        _ => (infer.generate(req.0).await?, None),
     };
 
     // Token details
@@ -483,7 +440,6 @@ async fn generate_stream(
                     // yield error
                     Err(err) => {
                         error = true;
-                        infer.set_healthy(false);
                         yield Ok(Event::from(err));
                         break;
                     }
@@ -595,9 +551,8 @@ pub async fn run(
         max_input_length,
         max_total_tokens,
     );
-    let health_ext = Health {
-        client: client.clone(),
-    };
+    let healthy = Arc::new(AtomicBool::new(false));
+    let health_ext = Health::new(client.clone(), healthy.clone());
     let infer = Infer::new(
         client,
         validation,
@@ -606,6 +561,7 @@ pub async fn run(
         max_waiting_tokens,
         max_concurrent_requests,
         shard_info.requires_padding,
+        healthy,
     );
 
     // Duration buckets
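
As a final illustration, this is a hedged sketch of how a handler like the new /health route turns the boolean check into HTTP status codes; it assumes the axum 0.6-style API the router used at the time, and FlagHealth/ErrorBody are stand-ins rather than the router's actual Health and ErrorResponse types.

use axum::{extract::Extension, http::StatusCode, routing::get, Json, Router};
use serde::Serialize;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Made-up stand-in for the router's `Health` extension: just the shared flag.
#[derive(Clone)]
struct FlagHealth(Arc<AtomicBool>);

#[derive(Serialize)]
struct ErrorBody {
    error: String,
    error_type: String,
}

// Mirrors the new handler shape: Ok(()) renders as 200, the Err tuple as 503 plus a JSON body.
async fn health(
    Extension(health): Extension<FlagHealth>,
) -> Result<(), (StatusCode, Json<ErrorBody>)> {
    if health.0.load(Ordering::SeqCst) {
        Ok(())
    } else {
        Err((
            StatusCode::SERVICE_UNAVAILABLE,
            Json(ErrorBody {
                error: "unhealthy".to_string(),
                error_type: "healthcheck".to_string(),
            }),
        ))
    }
}

#[tokio::main]
async fn main() {
    let app = Router::new()
        .route("/health", get(health))
        .layer(Extension(FlagHealth(Arc::new(AtomicBool::new(true)))));

    // `GET http://127.0.0.1:3000/health` now answers 200 or 503.
    axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}

Returning Err((StatusCode::SERVICE_UNAVAILABLE, Json(...))) is what produces the 503 now documented in the utoipa annotation, instead of the previous 500.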