Adding AtomicBool to see if healthcheck should do a full roundtrip or

not.
This commit is contained in:
Nicolas Patry 2023-04-26 15:29:03 +02:00
parent 9d613d0f9b
commit 3b2d1a2854
2 changed files with 71 additions and 46 deletions

View File

@ -7,7 +7,10 @@ use flume::SendError;
use futures::future::try_join_all;
use futures::stream::StreamExt;
use nohash_hasher::IntMap;
use std::sync::Arc;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use text_generation_client::{
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
};
@ -27,6 +30,8 @@ pub struct Infer {
shared: Arc<Shared>,
/// Inference limit
limit_concurrent_requests: Arc<Semaphore>,
/// Has done roundtrip valid run
healthy: Arc<AtomicBool>,
}
/// Infer shared state
@ -63,15 +68,25 @@ impl Infer {
// Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
let healthy = Arc::new(AtomicBool::new(false));
Self {
validation,
queue,
shared,
limit_concurrent_requests: semaphore,
healthy,
}
}
pub(crate) fn healthy(&self) -> bool {
self.healthy.load(Ordering::SeqCst)
}
pub(crate) fn set_healthy(&self, value: bool) {
self.healthy.store(value, Ordering::SeqCst)
}
/// Add a new request to the queue and return a stream of InferStreamResponse
#[instrument(skip(self))]
pub(crate) async fn generate_stream(

View File

@ -93,53 +93,46 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
)
)]
#[instrument]
#[instrument(skip(infer))]
/// Health check method
async fn health(
mut health: Extension<Health>,
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
health.client.health().await.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "unhealthy".to_string(),
error_type: "healthcheck".to_string(),
}),
)
})?;
Ok(axum::Json(()))
}
#[instrument(skip(infer))]
async fn health_generate(
infer: Extension<Infer>,
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
// be a bit too slow for a health check.
// What we should do instead is check if the gRPC channels are still healthy.
// Send a small inference request
infer
.generate(GenerateRequest {
inputs: "liveness".to_string(),
parameters: GenerateParameters {
best_of: None,
temperature: None,
repetition_penalty: None,
top_k: None,
top_p: None,
typical_p: None,
do_sample: false,
max_new_tokens: 1,
return_full_text: None,
stop: Vec::new(),
truncate: None,
watermark: false,
details: false,
seed: None,
},
})
.await?;
if infer.healthy() {
health.client.health().await.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "unhealthy".to_string(),
error_type: "healthcheck".to_string(),
}),
)
})?;
} else {
infer
.generate(GenerateRequest {
inputs: "liveness".to_string(),
parameters: GenerateParameters {
best_of: None,
temperature: None,
repetition_penalty: None,
top_k: None,
top_p: None,
typical_p: None,
do_sample: false,
max_new_tokens: 1,
return_full_text: None,
stop: Vec::new(),
truncate: None,
watermark: false,
details: false,
seed: None,
},
})
.await?;
infer.set_healthy(true);
}
Ok(axum::Json(()))
}
@ -191,10 +184,27 @@ async fn generate(
// Inference
let (response, best_of_responses) = match req.0.parameters.best_of {
Some(best_of) if best_of > 1 => {
let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
let (response, best_of_responses) = match infer.generate_best_of(req.0, best_of).await {
Ok(result) => result,
Err(err) => {
infer.set_healthy(false);
return Err(err)?;
}
};
(response, Some(best_of_responses))
}
_ => (infer.generate(req.0).await?, None),
_ => (
{
match infer.generate(req.0).await {
Ok(result) => result,
Err(err) => {
infer.set_healthy(false);
return Err(err)?;
}
}
},
None,
),
};
// Token details
@ -473,6 +483,7 @@ async fn generate_stream(
// yield error
Err(err) => {
error = true;
infer.set_healthy(false);
yield Ok(Event::from(err));
break;
}
@ -682,7 +693,6 @@ pub async fn run(
.route("/invocations", post(compat_generate))
// Base Health route
.route("/health", get(health))
.route("/health_generate", get(health_generate))
// Inference API health route
.route("/", get(health))
// AWS Sagemaker health route