diff --git a/router/src/infer.rs b/router/src/infer.rs
index 8b44ec86..ce5c5dce 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -7,7 +7,10 @@ use flume::SendError;
 use futures::future::try_join_all;
 use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
-use std::sync::Arc;
+use std::sync::{
+    atomic::{AtomicBool, Ordering},
+    Arc,
+};
 use text_generation_client::{
     Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
@@ -27,6 +30,8 @@ pub struct Infer {
     shared: Arc<Shared>,
     /// Inference limit
    limit_concurrent_requests: Arc<Semaphore>,
+    /// Has done roundtrip valid run
+    healthy: Arc<AtomicBool>,
 }
 
 /// Infer shared state
@@ -63,15 +68,25 @@ impl Infer {
 
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
 
+        let healthy = Arc::new(AtomicBool::new(false));
         Self {
             validation,
             queue,
             shared,
             limit_concurrent_requests: semaphore,
+            healthy,
         }
     }
 
+    pub(crate) fn healthy(&self) -> bool {
+        self.healthy.load(Ordering::SeqCst)
+    }
+
+    pub(crate) fn set_healthy(&self, value: bool) {
+        self.healthy.store(value, Ordering::SeqCst)
+    }
+
     /// Add a new request to the queue and return a stream of InferStreamResponse
     #[instrument(skip(self))]
     pub(crate) async fn generate_stream(
diff --git a/router/src/server.rs b/router/src/server.rs
index 97fd43d0..a0bdf42e 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -93,53 +93,46 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
             example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
     )
 )]
-#[instrument]
+#[instrument(skip(infer))]
 /// Health check method
 async fn health(
     mut health: Extension<Health>,
-) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
-    health.client.health().await.map_err(|_| {
-        (
-            StatusCode::INTERNAL_SERVER_ERROR,
-            Json(ErrorResponse {
-                error: "unhealthy".to_string(),
-                error_type: "healthcheck".to_string(),
-            }),
-        )
-    })?;
-    Ok(axum::Json(()))
-}
-
-#[instrument(skip(infer))]
-async fn health_generate(
     infer: Extension<Infer>,
 ) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
-    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
-    // be a bit too slow for a health check.
-    // What we should do instead is check if the gRPC channels are still healthy.
-
-    // Send a small inference request
-    infer
-        .generate(GenerateRequest {
-            inputs: "liveness".to_string(),
-            parameters: GenerateParameters {
-                best_of: None,
-                temperature: None,
-                repetition_penalty: None,
-                top_k: None,
-                top_p: None,
-                typical_p: None,
-                do_sample: false,
-                max_new_tokens: 1,
-                return_full_text: None,
-                stop: Vec::new(),
-                truncate: None,
-                watermark: false,
-                details: false,
-                seed: None,
-            },
-        })
-        .await?;
+    if infer.healthy() {
+        health.client.health().await.map_err(|_| {
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(ErrorResponse {
+                    error: "unhealthy".to_string(),
+                    error_type: "healthcheck".to_string(),
+                }),
+            )
+        })?;
+    } else {
+        infer
+            .generate(GenerateRequest {
+                inputs: "liveness".to_string(),
+                parameters: GenerateParameters {
+                    best_of: None,
+                    temperature: None,
+                    repetition_penalty: None,
+                    top_k: None,
+                    top_p: None,
+                    typical_p: None,
+                    do_sample: false,
+                    max_new_tokens: 1,
+                    return_full_text: None,
+                    stop: Vec::new(),
+                    truncate: None,
+                    watermark: false,
+                    details: false,
+                    seed: None,
+                },
+            })
+            .await?;
+        infer.set_healthy(true);
+    }
     Ok(axum::Json(()))
 }
 
@@ -191,10 +184,27 @@ async fn generate(
     // Inference
     let (response, best_of_responses) = match req.0.parameters.best_of {
         Some(best_of) if best_of > 1 => {
-            let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
+            let (response, best_of_responses) = match infer.generate_best_of(req.0, best_of).await {
+                Ok(result) => result,
+                Err(err) => {
+                    infer.set_healthy(false);
+                    return Err(err)?;
+                }
+            };
             (response, Some(best_of_responses))
         }
-        _ => (infer.generate(req.0).await?, None),
+        _ => (
+            {
+                match infer.generate(req.0).await {
+                    Ok(result) => result,
+                    Err(err) => {
+                        infer.set_healthy(false);
+                        return Err(err)?;
+                    }
+                }
+            },
+            None,
+        ),
     };
 
     // Token details
@@ -473,6 +483,7 @@ async fn generate_stream(
                         // yield error
                         Err(err) => {
                             error = true;
+                            infer.set_healthy(false);
                             yield Ok(Event::from(err));
                             break;
                         }
@@ -682,7 +693,6 @@ pub async fn run(
         .route("/invocations", post(compat_generate))
         // Base Health route
         .route("/health", get(health))
-        .route("/health_generate", get(health_generate))
         // Inference API health route
         .route("/", get(health))
         // AWS Sagemaker health route