mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
Adding AtomicBool to see if healthcheck should do a full roundtrip or
not.
This commit is contained in:
parent
9d613d0f9b
commit
3b2d1a2854
@ -7,7 +7,10 @@ use flume::SendError;
|
|||||||
use futures::future::try_join_all;
|
use futures::future::try_join_all;
|
||||||
use futures::stream::StreamExt;
|
use futures::stream::StreamExt;
|
||||||
use nohash_hasher::IntMap;
|
use nohash_hasher::IntMap;
|
||||||
use std::sync::Arc;
|
use std::sync::{
|
||||||
|
atomic::{AtomicBool, Ordering},
|
||||||
|
Arc,
|
||||||
|
};
|
||||||
use text_generation_client::{
|
use text_generation_client::{
|
||||||
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
|
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
|
||||||
};
|
};
|
||||||
@ -27,6 +30,8 @@ pub struct Infer {
|
|||||||
shared: Arc<Shared>,
|
shared: Arc<Shared>,
|
||||||
/// Inference limit
|
/// Inference limit
|
||||||
limit_concurrent_requests: Arc<Semaphore>,
|
limit_concurrent_requests: Arc<Semaphore>,
|
||||||
|
/// Has done roundtrip valid run
|
||||||
|
healthy: Arc<AtomicBool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Infer shared state
|
/// Infer shared state
|
||||||
@ -63,15 +68,25 @@ impl Infer {
|
|||||||
|
|
||||||
// Inference limit with a semaphore
|
// Inference limit with a semaphore
|
||||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||||
|
let healthy = Arc::new(AtomicBool::new(false));
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
validation,
|
validation,
|
||||||
queue,
|
queue,
|
||||||
shared,
|
shared,
|
||||||
limit_concurrent_requests: semaphore,
|
limit_concurrent_requests: semaphore,
|
||||||
|
healthy,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn healthy(&self) -> bool {
|
||||||
|
self.healthy.load(Ordering::SeqCst)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn set_healthy(&self, value: bool) {
|
||||||
|
self.healthy.store(value, Ordering::SeqCst)
|
||||||
|
}
|
||||||
|
|
||||||
/// Add a new request to the queue and return a stream of InferStreamResponse
|
/// Add a new request to the queue and return a stream of InferStreamResponse
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub(crate) async fn generate_stream(
|
pub(crate) async fn generate_stream(
|
||||||
|
@ -93,53 +93,46 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
|
|||||||
example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
|
example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
#[instrument]
|
#[instrument(skip(infer))]
|
||||||
/// Health check method
|
/// Health check method
|
||||||
async fn health(
|
async fn health(
|
||||||
mut health: Extension<Health>,
|
mut health: Extension<Health>,
|
||||||
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
|
|
||||||
health.client.health().await.map_err(|_| {
|
|
||||||
(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: "unhealthy".to_string(),
|
|
||||||
error_type: "healthcheck".to_string(),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
Ok(axum::Json(()))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip(infer))]
|
|
||||||
async fn health_generate(
|
|
||||||
infer: Extension<Infer>,
|
infer: Extension<Infer>,
|
||||||
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
|
||||||
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
if infer.healthy() {
|
||||||
// be a bit too slow for a health check.
|
health.client.health().await.map_err(|_| {
|
||||||
// What we should do instead is check if the gRPC channels are still healthy.
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
// Send a small inference request
|
Json(ErrorResponse {
|
||||||
infer
|
error: "unhealthy".to_string(),
|
||||||
.generate(GenerateRequest {
|
error_type: "healthcheck".to_string(),
|
||||||
inputs: "liveness".to_string(),
|
}),
|
||||||
parameters: GenerateParameters {
|
)
|
||||||
best_of: None,
|
})?;
|
||||||
temperature: None,
|
} else {
|
||||||
repetition_penalty: None,
|
infer
|
||||||
top_k: None,
|
.generate(GenerateRequest {
|
||||||
top_p: None,
|
inputs: "liveness".to_string(),
|
||||||
typical_p: None,
|
parameters: GenerateParameters {
|
||||||
do_sample: false,
|
best_of: None,
|
||||||
max_new_tokens: 1,
|
temperature: None,
|
||||||
return_full_text: None,
|
repetition_penalty: None,
|
||||||
stop: Vec::new(),
|
top_k: None,
|
||||||
truncate: None,
|
top_p: None,
|
||||||
watermark: false,
|
typical_p: None,
|
||||||
details: false,
|
do_sample: false,
|
||||||
seed: None,
|
max_new_tokens: 1,
|
||||||
},
|
return_full_text: None,
|
||||||
})
|
stop: Vec::new(),
|
||||||
.await?;
|
truncate: None,
|
||||||
|
watermark: false,
|
||||||
|
details: false,
|
||||||
|
seed: None,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
infer.set_healthy(true);
|
||||||
|
}
|
||||||
Ok(axum::Json(()))
|
Ok(axum::Json(()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -191,10 +184,27 @@ async fn generate(
|
|||||||
// Inference
|
// Inference
|
||||||
let (response, best_of_responses) = match req.0.parameters.best_of {
|
let (response, best_of_responses) = match req.0.parameters.best_of {
|
||||||
Some(best_of) if best_of > 1 => {
|
Some(best_of) if best_of > 1 => {
|
||||||
let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
|
let (response, best_of_responses) = match infer.generate_best_of(req.0, best_of).await {
|
||||||
|
Ok(result) => result,
|
||||||
|
Err(err) => {
|
||||||
|
infer.set_healthy(false);
|
||||||
|
return Err(err)?;
|
||||||
|
}
|
||||||
|
};
|
||||||
(response, Some(best_of_responses))
|
(response, Some(best_of_responses))
|
||||||
}
|
}
|
||||||
_ => (infer.generate(req.0).await?, None),
|
_ => (
|
||||||
|
{
|
||||||
|
match infer.generate(req.0).await {
|
||||||
|
Ok(result) => result,
|
||||||
|
Err(err) => {
|
||||||
|
infer.set_healthy(false);
|
||||||
|
return Err(err)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
None,
|
||||||
|
),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Token details
|
// Token details
|
||||||
@ -473,6 +483,7 @@ async fn generate_stream(
|
|||||||
// yield error
|
// yield error
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
error = true;
|
error = true;
|
||||||
|
infer.set_healthy(false);
|
||||||
yield Ok(Event::from(err));
|
yield Ok(Event::from(err));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -682,7 +693,6 @@ pub async fn run(
|
|||||||
.route("/invocations", post(compat_generate))
|
.route("/invocations", post(compat_generate))
|
||||||
// Base Health route
|
// Base Health route
|
||||||
.route("/health", get(health))
|
.route("/health", get(health))
|
||||||
.route("/health_generate", get(health_generate))
|
|
||||||
// Inference API health route
|
// Inference API health route
|
||||||
.route("/", get(health))
|
.route("/", get(health))
|
||||||
// AWS Sagemaker health route
|
// AWS Sagemaker health route
|
||||||
|
Loading…
Reference in New Issue
Block a user