mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00
Adding AtomicBool to see if healthcheck should do a full roundtrip or
not.
This commit is contained in:
parent
9d613d0f9b
commit
3b2d1a2854
@ -7,7 +7,10 @@ use flume::SendError;
|
||||
use futures::future::try_join_all;
|
||||
use futures::stream::StreamExt;
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use text_generation_client::{
|
||||
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
|
||||
};
|
||||
@ -27,6 +30,8 @@ pub struct Infer {
|
||||
shared: Arc<Shared>,
|
||||
/// Inference limit
|
||||
limit_concurrent_requests: Arc<Semaphore>,
|
||||
/// Has done roundtrip valid run
|
||||
healthy: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
/// Infer shared state
|
||||
@ -63,15 +68,25 @@ impl Infer {
|
||||
|
||||
// Inference limit with a semaphore
|
||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||
let healthy = Arc::new(AtomicBool::new(false));
|
||||
|
||||
Self {
|
||||
validation,
|
||||
queue,
|
||||
shared,
|
||||
limit_concurrent_requests: semaphore,
|
||||
healthy,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn healthy(&self) -> bool {
|
||||
self.healthy.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
pub(crate) fn set_healthy(&self, value: bool) {
|
||||
self.healthy.store(value, Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Add a new request to the queue and return a stream of InferStreamResponse
|
||||
#[instrument(skip(self))]
|
||||
pub(crate) async fn generate_stream(
|
||||
|
@ -93,53 +93,46 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
|
||||
example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
|
||||
)
|
||||
)]
|
||||
#[instrument]
|
||||
#[instrument(skip(infer))]
|
||||
/// Health check method
|
||||
async fn health(
|
||||
mut health: Extension<Health>,
|
||||
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
|
||||
health.client.health().await.map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "unhealthy".to_string(),
|
||||
error_type: "healthcheck".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
Ok(axum::Json(()))
|
||||
}
|
||||
|
||||
#[instrument(skip(infer))]
|
||||
async fn health_generate(
|
||||
infer: Extension<Infer>,
|
||||
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
|
||||
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
||||
// be a bit too slow for a health check.
|
||||
// What we should do instead is check if the gRPC channels are still healthy.
|
||||
|
||||
// Send a small inference request
|
||||
infer
|
||||
.generate(GenerateRequest {
|
||||
inputs: "liveness".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature: None,
|
||||
repetition_penalty: None,
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
typical_p: None,
|
||||
do_sample: false,
|
||||
max_new_tokens: 1,
|
||||
return_full_text: None,
|
||||
stop: Vec::new(),
|
||||
truncate: None,
|
||||
watermark: false,
|
||||
details: false,
|
||||
seed: None,
|
||||
},
|
||||
})
|
||||
.await?;
|
||||
if infer.healthy() {
|
||||
health.client.health().await.map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "unhealthy".to_string(),
|
||||
error_type: "healthcheck".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
} else {
|
||||
infer
|
||||
.generate(GenerateRequest {
|
||||
inputs: "liveness".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature: None,
|
||||
repetition_penalty: None,
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
typical_p: None,
|
||||
do_sample: false,
|
||||
max_new_tokens: 1,
|
||||
return_full_text: None,
|
||||
stop: Vec::new(),
|
||||
truncate: None,
|
||||
watermark: false,
|
||||
details: false,
|
||||
seed: None,
|
||||
},
|
||||
})
|
||||
.await?;
|
||||
infer.set_healthy(true);
|
||||
}
|
||||
Ok(axum::Json(()))
|
||||
}
|
||||
|
||||
@ -191,10 +184,27 @@ async fn generate(
|
||||
// Inference
|
||||
let (response, best_of_responses) = match req.0.parameters.best_of {
|
||||
Some(best_of) if best_of > 1 => {
|
||||
let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
|
||||
let (response, best_of_responses) = match infer.generate_best_of(req.0, best_of).await {
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
infer.set_healthy(false);
|
||||
return Err(err)?;
|
||||
}
|
||||
};
|
||||
(response, Some(best_of_responses))
|
||||
}
|
||||
_ => (infer.generate(req.0).await?, None),
|
||||
_ => (
|
||||
{
|
||||
match infer.generate(req.0).await {
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
infer.set_healthy(false);
|
||||
return Err(err)?;
|
||||
}
|
||||
}
|
||||
},
|
||||
None,
|
||||
),
|
||||
};
|
||||
|
||||
// Token details
|
||||
@ -473,6 +483,7 @@ async fn generate_stream(
|
||||
// yield error
|
||||
Err(err) => {
|
||||
error = true;
|
||||
infer.set_healthy(false);
|
||||
yield Ok(Event::from(err));
|
||||
break;
|
||||
}
|
||||
@ -682,7 +693,6 @@ pub async fn run(
|
||||
.route("/invocations", post(compat_generate))
|
||||
// Base Health route
|
||||
.route("/health", get(health))
|
||||
.route("/health_generate", get(health_generate))
|
||||
// Inference API health route
|
||||
.route("/", get(health))
|
||||
// AWS Sagemaker health route
|
||||
|
Loading…
Reference in New Issue
Block a user