mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
New healthcheck that doesn't hit the queue.
This commit is contained in:
parent
7de8a377b0
commit
e1867079fd
@ -15,8 +15,13 @@ service TextGenerationService {
|
||||
rpc Prefill (PrefillRequest) returns (PrefillResponse);
|
||||
/// Decode token for a list of prefilled batches
|
||||
rpc Decode (DecodeRequest) returns (DecodeResponse);
|
||||
/// Health check
|
||||
rpc Health (HealthRequest) returns (HealthResponse);
|
||||
}
|
||||
|
||||
/// Empty request taken by the Health RPC
message HealthRequest {}
|
||||
/// Empty response returned by the Health RPC
message HealthResponse {}
|
||||
|
||||
/// Empty request
|
||||
message InfoRequest {}
|
||||
|
||||
@ -173,4 +178,4 @@ message DecodeResponse {
|
||||
repeated Generation generations = 1;
|
||||
/// Next batch (cached)
|
||||
optional Batch batch = 2;
|
||||
}
|
||||
}
|
||||
|
@ -7,7 +7,7 @@ use tonic::transport::{Channel, Uri};
|
||||
use tracing::instrument;
|
||||
|
||||
/// Text Generation Inference gRPC client
|
||||
#[derive(Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Client {
|
||||
stub: TextGenerationServiceClient<Channel>,
|
||||
}
|
||||
@ -62,6 +62,14 @@ impl Client {
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Get model health
|
||||
#[instrument(skip(self))]
|
||||
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||
let request = tonic::Request::new(HealthRequest {}).inject_context();
|
||||
let response = self.stub.health(request).await?.into_inner();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Clear the past generations cache
|
||||
#[instrument(skip(self))]
|
||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||
|
@ -6,6 +6,7 @@ mod pb;
|
||||
mod sharded_client;
|
||||
|
||||
pub use client::Client;
|
||||
pub use pb::generate::v1::HealthResponse;
|
||||
pub use pb::generate::v1::InfoResponse as ShardInfo;
|
||||
pub use pb::generate::v1::{
|
||||
Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
|
||||
|
@ -1,10 +1,11 @@
|
||||
/// Multi shard Client
|
||||
use crate::Result;
|
||||
use crate::{Batch, Client, Generation, Request, ShardInfo};
|
||||
use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
|
||||
use futures::future::join_all;
|
||||
use tonic::transport::Uri;
|
||||
use tracing::instrument;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Text Generation Inference gRPC multi client
|
||||
pub struct ShardedClient {
|
||||
clients: Vec<Client>,
|
||||
@ -48,6 +49,17 @@ impl ShardedClient {
|
||||
join_all(futures).await.pop().unwrap()
|
||||
}
|
||||
|
||||
/// GRPC health check
|
||||
#[instrument(skip(self))]
|
||||
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.health())
|
||||
.collect();
|
||||
join_all(futures).await.pop().unwrap()
|
||||
}
|
||||
|
||||
/// Clear the past generations cache
|
||||
#[instrument(skip(self))]
|
||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||
|
@ -7,6 +7,7 @@ mod validation;
|
||||
use infer::Infer;
|
||||
use queue::{Entry, Queue};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use text_generation_client::ShardedClient;
|
||||
use utoipa::ToSchema;
|
||||
use validation::Validation;
|
||||
|
||||
@ -19,6 +20,20 @@ pub struct HubModelInfo {
|
||||
pub pipeline_tag: Option<String>,
|
||||
}
|
||||
|
||||
/// Field-less health response type deriving `Serialize` and `ToSchema`.
// NOTE(review): the health handlers visible in this change return `Json<()>`;
// confirm where this type is actually used.
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
pub struct HealthResponse {}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Health {
|
||||
pub client: ShardedClient,
|
||||
}
|
||||
|
||||
impl Health {
|
||||
pub fn new(client: ShardedClient) -> Self {
|
||||
Self { client }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
pub struct Info {
|
||||
/// Model info
|
||||
|
@ -3,8 +3,8 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
||||
use crate::validation::ValidationError;
|
||||
use crate::{
|
||||
BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
|
||||
GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
|
||||
StreamDetails, StreamResponse, Token, Validation,
|
||||
GenerateParameters, GenerateRequest, GenerateResponse, Health, HubModelInfo, Infer, Info,
|
||||
PrefillToken, StreamDetails, StreamResponse, Token, Validation,
|
||||
};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
@ -82,9 +82,43 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
|
||||
Json(info.0)
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/health",
|
||||
request_body = HealthRequest,
|
||||
responses(
|
||||
(status = 200, description = "Everything is working fine"),
|
||||
(status = 500, description = "Text generation inference is down", body = ErrorResponse,
|
||||
example = json ! ({"error": "unhealthy"})),
|
||||
)
|
||||
)]
|
||||
#[instrument]
|
||||
/// Health check method
|
||||
async fn health(
|
||||
mut health: Extension<Health>,
|
||||
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
|
||||
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
||||
// be a bit too slow for a health check.
|
||||
// What we should do instead is check if the gRPC channels are still healthy.
|
||||
|
||||
// Send a small inference request
|
||||
health.client.health().await.map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "unhealthy".to_string(),
|
||||
error_type: "healthcheck".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
Ok(axum::Json(()))
|
||||
}
|
||||
|
||||
#[instrument(skip(infer))]
|
||||
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||
async fn health_generate(
|
||||
infer: Extension<Infer>,
|
||||
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
|
||||
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
||||
// be a bit too slow for a health check.
|
||||
// What we should do instead is check if the gRPC channels are still healthy.
|
||||
@ -111,7 +145,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
||||
},
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
Ok(axum::Json(()))
|
||||
}
|
||||
|
||||
/// Generate tokens
|
||||
@ -555,6 +589,7 @@ pub async fn run(
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
);
|
||||
let health_ext = Health::new(client.clone());
|
||||
let infer = Infer::new(
|
||||
client,
|
||||
validation,
|
||||
@ -650,6 +685,7 @@ pub async fn run(
|
||||
.route("/invocations", post(compat_generate))
|
||||
// Base Health route
|
||||
.route("/health", get(health))
|
||||
.route("/health_generate", get(health_generate))
|
||||
// Inference API health route
|
||||
.route("/", get(health))
|
||||
// AWS Sagemaker health route
|
||||
@ -657,6 +693,7 @@ pub async fn run(
|
||||
// Prometheus metrics route
|
||||
.route("/metrics", get(metrics))
|
||||
.layer(Extension(info))
|
||||
.layer(Extension(health_ext))
|
||||
.layer(Extension(compat_return_full_text))
|
||||
.layer(Extension(infer))
|
||||
.layer(Extension(prom_handle))
|
||||
|
@ -29,6 +29,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
async def Info(self, request, context):
    """Serve the Info RPC by returning the model's ``info`` attribute."""
    model_info = self.model.info
    return model_info
|
||||
|
||||
async def Health(self, request, context):
    """Answer the Health RPC with an empty ``HealthResponse``.

    Neither ``request`` nor ``context`` is inspected.
    """
    reply = generate_pb2.HealthResponse()
    return reply
|
||||
|
||||
async def ServiceDiscovery(self, request, context):
    """Serve the ServiceDiscovery RPC with the URLs held in ``self.server_urls``."""
    known_urls = self.server_urls
    return generate_pb2.ServiceDiscoveryResponse(urls=known_urls)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user