Adding AtomicBool to see if healthcheck should do a full roundtrip or not.
Nicolas Patry 2023-04-26 15:29:03 +02:00
parent 9d613d0f9b
commit 3b2d1a2854
2 changed files with 71 additions and 46 deletions
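
In short: the router now carries a shared AtomicBool that starts out false. While the flag is false, a health check performs a full one-token "liveness" generation and, on success, sets the flag; once the flag is true, later health checks only ping the gRPC client, and any failed generation clears the flag again. Below is a minimal, self-contained sketch of that flag pattern; the HealthFlag type is illustrative, not the actual Infer struct from this commit.

use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc,
};

// Illustrative stand-in for the health flag the commit adds to `Infer`.
#[derive(Clone)]
struct HealthFlag(Arc<AtomicBool>);

impl HealthFlag {
    fn new() -> Self {
        // Starts unhealthy: no full roundtrip has succeeded yet.
        HealthFlag(Arc::new(AtomicBool::new(false)))
    }
    fn healthy(&self) -> bool {
        self.0.load(Ordering::SeqCst)
    }
    fn set_healthy(&self, value: bool) {
        self.0.store(value, Ordering::SeqCst)
    }
}

fn main() {
    let flag = HealthFlag::new();
    assert!(!flag.healthy()); // fresh state: force the expensive roundtrip
    flag.set_healthy(true);   // a full generation roundtrip succeeded
    assert!(flag.healthy());  // later checks can take the cheap path
}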

File 1 of 2

@@ -7,7 +7,10 @@ use flume::SendError;
 use futures::future::try_join_all;
 use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
-use std::sync::Arc;
+use std::sync::{
+    atomic::{AtomicBool, Ordering},
+    Arc,
+};
 use text_generation_client::{
     Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
@@ -27,6 +30,8 @@ pub struct Infer {
     shared: Arc<Shared>,
     /// Inference limit
     limit_concurrent_requests: Arc<Semaphore>,
+    /// Has done roundtrip valid run
+    healthy: Arc<AtomicBool>,
 }
 
 /// Infer shared state
@@ -63,15 +68,25 @@ impl Infer {
 
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
+        let healthy = Arc::new(AtomicBool::new(false));
 
         Self {
             validation,
             queue,
             shared,
             limit_concurrent_requests: semaphore,
+            healthy,
         }
     }
 
+    pub(crate) fn healthy(&self) -> bool {
+        self.healthy.load(Ordering::SeqCst)
+    }
+
+    pub(crate) fn set_healthy(&self, value: bool) {
+        self.healthy.store(value, Ordering::SeqCst)
+    }
+
     /// Add a new request to the queue and return a stream of InferStreamResponse
     #[instrument(skip(self))]
     pub(crate) async fn generate_stream(
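
A note on the two accessors above: both use Ordering::SeqCst, the strongest memory ordering. For an isolated status flag like this, where no other memory accesses are synchronized through the flag, Ordering::Relaxed would arguably suffice; SeqCst simply errs on the conservative side, at negligible cost for a low-frequency health check.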

File 2 of 2

@@ -93,53 +93,46 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
     example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
     )
 )]
-#[instrument]
+#[instrument(skip(infer))]
 /// Health check method
 async fn health(
     mut health: Extension<Health>,
-) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
-    health.client.health().await.map_err(|_| {
-        (
-            StatusCode::INTERNAL_SERVER_ERROR,
-            Json(ErrorResponse {
-                error: "unhealthy".to_string(),
-                error_type: "healthcheck".to_string(),
-            }),
-        )
-    })?;
-    Ok(axum::Json(()))
-}
-
-#[instrument(skip(infer))]
-async fn health_generate(
     infer: Extension<Infer>,
 ) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
-    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
-    // be a bit too slow for a health check.
-    // What we should do instead is check if the gRPC channels are still healthy.
-
-    // Send a small inference request
-    infer
-        .generate(GenerateRequest {
-            inputs: "liveness".to_string(),
-            parameters: GenerateParameters {
-                best_of: None,
-                temperature: None,
-                repetition_penalty: None,
-                top_k: None,
-                top_p: None,
-                typical_p: None,
-                do_sample: false,
-                max_new_tokens: 1,
-                return_full_text: None,
-                stop: Vec::new(),
-                truncate: None,
-                watermark: false,
-                details: false,
-                seed: None,
-            },
-        })
-        .await?;
+    if infer.healthy() {
+        health.client.health().await.map_err(|_| {
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(ErrorResponse {
+                    error: "unhealthy".to_string(),
+                    error_type: "healthcheck".to_string(),
+                }),
+            )
+        })?;
+    } else {
+        infer
+            .generate(GenerateRequest {
+                inputs: "liveness".to_string(),
+                parameters: GenerateParameters {
+                    best_of: None,
+                    temperature: None,
+                    repetition_penalty: None,
+                    top_k: None,
+                    top_p: None,
+                    typical_p: None,
+                    do_sample: false,
+                    max_new_tokens: 1,
+                    return_full_text: None,
+                    stop: Vec::new(),
+                    truncate: None,
+                    watermark: false,
+                    details: false,
+                    seed: None,
+                },
+            })
+            .await?;
+        infer.set_healthy(true);
+    }
     Ok(axum::Json(()))
 }
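
Stripped of the axum extractors and error types, the new handler reduces to a two-path check. The sketch below mirrors that control flow; Backend, ping, and generate_one_token are hypothetical stand-ins for the client ping and the one-token "liveness" request, not names from the codebase.

use std::sync::atomic::{AtomicBool, Ordering};

struct Backend {
    healthy: AtomicBool,
}

impl Backend {
    async fn ping(&self) -> Result<(), ()> {
        Ok(()) // stands in for `health.client.health().await`
    }
    async fn generate_one_token(&self, _inputs: &str) -> Result<(), ()> {
        Ok(()) // stands in for the max_new_tokens = 1 "liveness" request
    }
}

async fn health_check(backend: &Backend) -> Result<(), &'static str> {
    if backend.healthy.load(Ordering::SeqCst) {
        // Cheap path: a prior roundtrip succeeded, so only ping the shards.
        backend.ping().await.map_err(|_| "unhealthy")
    } else {
        // Expensive path: prove end-to-end generation works once,
        // then remember it so later checks stay cheap.
        backend
            .generate_one_token("liveness")
            .await
            .map_err(|_| "unhealthy")?;
        backend.healthy.store(true, Ordering::SeqCst);
        Ok(())
    }
}

Note that a failure on the cheap path does not clear the flag; in this commit only errors in the actual generate paths (below) reset it to false.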
@@ -191,10 +184,27 @@ async fn generate(
     // Inference
     let (response, best_of_responses) = match req.0.parameters.best_of {
         Some(best_of) if best_of > 1 => {
-            let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
+            let (response, best_of_responses) = match infer.generate_best_of(req.0, best_of).await {
+                Ok(result) => result,
+                Err(err) => {
+                    infer.set_healthy(false);
+                    return Err(err)?;
+                }
+            };
             (response, Some(best_of_responses))
         }
-        _ => (infer.generate(req.0).await?, None),
+        _ => (
+            {
+                match infer.generate(req.0).await {
+                    Ok(result) => result,
+                    Err(err) => {
+                        infer.set_healthy(false);
+                        return Err(err)?;
+                    }
+                }
+            },
+            None,
+        ),
     };
 
     // Token details
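
The explicit match arms above exist only to flip the flag before propagating the error. An equivalent, more compact shape (a hypothetical refactor, not what the commit ships) factors the pattern into a helper built on Result::map_err:

use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};

// Hypothetical helper: await a fallible future and clear the health flag
// on error before propagating it, mirroring the match blocks above.
async fn track_health<T, E>(
    healthy: &AtomicBool,
    fut: impl Future<Output = Result<T, E>>,
) -> Result<T, E> {
    fut.await.map_err(|err| {
        healthy.store(false, Ordering::SeqCst);
        err
    })
}

With such a helper, each call site would collapse to something like track_health(&healthy, infer.generate(req.0)).await?.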
@@ -473,6 +483,7 @@ async fn generate_stream(
                         // yield error
                         Err(err) => {
                             error = true;
+                            infer.set_healthy(false);
                             yield Ok(Event::from(err));
                             break;
                         }
@@ -682,7 +693,6 @@ pub async fn run(
         .route("/invocations", post(compat_generate))
         // Base Health route
         .route("/health", get(health))
-        .route("/health_generate", get(health_generate))
         // Inference API health route
         .route("/", get(health))
         // AWS Sagemaker health route
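
With the standalone health_generate handler gone, every health-style route now resolves to the same stateful handler. A toy axum router showing the resulting shape; the trivial handler body is a placeholder, not the real one:

use axum::{routing::get, Router};

// Placeholder handler; the real one takes the Extension<Health> and
// Extension<Infer> state shown in the diff above.
async fn health() -> &'static str {
    "ok"
}

fn app() -> Router {
    Router::new()
        // The base health route and the Inference API health route share
        // one handler; the dedicated /health_generate route is removed.
        .route("/health", get(health))
        .route("/", get(health))
}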