mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
New healthcheck that doesn't hit the queue.
This commit is contained in:
parent
7de8a377b0
commit
e1867079fd
@ -15,8 +15,13 @@ service TextGenerationService {
|
|||||||
rpc Prefill (PrefillRequest) returns (PrefillResponse);
|
rpc Prefill (PrefillRequest) returns (PrefillResponse);
|
||||||
/// Decode token for a list of prefilled batches
|
/// Decode token for a list of prefilled batches
|
||||||
rpc Decode (DecodeRequest) returns (DecodeResponse);
|
rpc Decode (DecodeRequest) returns (DecodeResponse);
|
||||||
|
/// Health check
|
||||||
|
rpc Health (HealthRequest) returns (HealthResponse);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message HealthRequest {}
|
||||||
|
message HealthResponse {}
|
||||||
|
|
||||||
/// Empty request
|
/// Empty request
|
||||||
message InfoRequest {}
|
message InfoRequest {}
|
||||||
|
|
||||||
@ -173,4 +178,4 @@ message DecodeResponse {
|
|||||||
repeated Generation generations = 1;
|
repeated Generation generations = 1;
|
||||||
/// Next batch (cached)
|
/// Next batch (cached)
|
||||||
optional Batch batch = 2;
|
optional Batch batch = 2;
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,7 @@ use tonic::transport::{Channel, Uri};
|
|||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
|
||||||
/// Text Generation Inference gRPC client
|
/// Text Generation Inference gRPC client
|
||||||
#[derive(Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Client {
|
pub struct Client {
|
||||||
stub: TextGenerationServiceClient<Channel>,
|
stub: TextGenerationServiceClient<Channel>,
|
||||||
}
|
}
|
||||||
@ -62,6 +62,14 @@ impl Client {
|
|||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get model health
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||||
|
let request = tonic::Request::new(HealthRequest {}).inject_context();
|
||||||
|
let response = self.stub.health(request).await?.into_inner();
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
/// Clear the past generations cache
|
/// Clear the past generations cache
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||||
|
@ -6,6 +6,7 @@ mod pb;
|
|||||||
mod sharded_client;
|
mod sharded_client;
|
||||||
|
|
||||||
pub use client::Client;
|
pub use client::Client;
|
||||||
|
pub use pb::generate::v1::HealthResponse;
|
||||||
pub use pb::generate::v1::InfoResponse as ShardInfo;
|
pub use pb::generate::v1::InfoResponse as ShardInfo;
|
||||||
pub use pb::generate::v1::{
|
pub use pb::generate::v1::{
|
||||||
Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
|
Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
/// Multi shard Client
|
/// Multi shard Client
|
||||||
use crate::Result;
|
use crate::Result;
|
||||||
use crate::{Batch, Client, Generation, Request, ShardInfo};
|
use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
|
||||||
use futures::future::join_all;
|
use futures::future::join_all;
|
||||||
use tonic::transport::Uri;
|
use tonic::transport::Uri;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
/// Text Generation Inference gRPC multi client
|
/// Text Generation Inference gRPC multi client
|
||||||
pub struct ShardedClient {
|
pub struct ShardedClient {
|
||||||
clients: Vec<Client>,
|
clients: Vec<Client>,
|
||||||
@ -48,6 +49,17 @@ impl ShardedClient {
|
|||||||
join_all(futures).await.pop().unwrap()
|
join_all(futures).await.pop().unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// GRPC health check
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| client.health())
|
||||||
|
.collect();
|
||||||
|
join_all(futures).await.pop().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
/// Clear the past generations cache
|
/// Clear the past generations cache
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||||
|
@ -7,6 +7,7 @@ mod validation;
|
|||||||
use infer::Infer;
|
use infer::Infer;
|
||||||
use queue::{Entry, Queue};
|
use queue::{Entry, Queue};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use text_generation_client::ShardedClient;
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
use validation::Validation;
|
use validation::Validation;
|
||||||
|
|
||||||
@ -19,6 +20,20 @@ pub struct HubModelInfo {
|
|||||||
pub pipeline_tag: Option<String>,
|
pub pipeline_tag: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||||
|
pub struct HealthResponse {}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct Health {
|
||||||
|
pub client: ShardedClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Health {
|
||||||
|
pub fn new(client: ShardedClient) -> Self {
|
||||||
|
Self { client }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||||
pub struct Info {
|
pub struct Info {
|
||||||
/// Model info
|
/// Model info
|
||||||
|
@ -3,8 +3,8 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
|||||||
use crate::validation::ValidationError;
|
use crate::validation::ValidationError;
|
||||||
use crate::{
|
use crate::{
|
||||||
BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
|
BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
|
||||||
GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
|
GenerateParameters, GenerateRequest, GenerateResponse, Health, HubModelInfo, Infer, Info,
|
||||||
StreamDetails, StreamResponse, Token, Validation,
|
PrefillToken, StreamDetails, StreamResponse, Token, Validation,
|
||||||
};
|
};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, Method, StatusCode};
|
use axum::http::{HeaderMap, Method, StatusCode};
|
||||||
@ -82,9 +82,43 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
|
|||||||
Json(info.0)
|
Json(info.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[utoipa::path(
|
||||||
|
get,
|
||||||
|
tag = "Text Generation Inference",
|
||||||
|
path = "/health",
|
||||||
|
request_body = HealthRequest,
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Everything is working fine"),
|
||||||
|
(status = 500, description = "Text generation inference is down", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "unhealthy"})),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
#[instrument]
|
||||||
/// Health check method
|
/// Health check method
|
||||||
|
async fn health(
|
||||||
|
mut health: Extension<Health>,
|
||||||
|
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
||||||
|
// be a bit too slow for a health check.
|
||||||
|
// What we should do instead is check if the gRPC channels are still healthy.
|
||||||
|
|
||||||
|
// Send a small inference request
|
||||||
|
health.client.health().await.map_err(|_| {
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: "unhealthy".to_string(),
|
||||||
|
error_type: "healthcheck".to_string(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
Ok(axum::Json(()))
|
||||||
|
}
|
||||||
|
|
||||||
#[instrument(skip(infer))]
|
#[instrument(skip(infer))]
|
||||||
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
async fn health_generate(
|
||||||
|
infer: Extension<Infer>,
|
||||||
|
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
|
||||||
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
||||||
// be a bit too slow for a health check.
|
// be a bit too slow for a health check.
|
||||||
// What we should do instead is check if the gRPC channels are still healthy.
|
// What we should do instead is check if the gRPC channels are still healthy.
|
||||||
@ -111,7 +145,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(axum::Json(()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate tokens
|
/// Generate tokens
|
||||||
@ -555,6 +589,7 @@ pub async fn run(
|
|||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
);
|
);
|
||||||
|
let health_ext = Health::new(client.clone());
|
||||||
let infer = Infer::new(
|
let infer = Infer::new(
|
||||||
client,
|
client,
|
||||||
validation,
|
validation,
|
||||||
@ -650,6 +685,7 @@ pub async fn run(
|
|||||||
.route("/invocations", post(compat_generate))
|
.route("/invocations", post(compat_generate))
|
||||||
// Base Health route
|
// Base Health route
|
||||||
.route("/health", get(health))
|
.route("/health", get(health))
|
||||||
|
.route("/health_generate", get(health_generate))
|
||||||
// Inference API health route
|
// Inference API health route
|
||||||
.route("/", get(health))
|
.route("/", get(health))
|
||||||
// AWS Sagemaker health route
|
// AWS Sagemaker health route
|
||||||
@ -657,6 +693,7 @@ pub async fn run(
|
|||||||
// Prometheus metrics route
|
// Prometheus metrics route
|
||||||
.route("/metrics", get(metrics))
|
.route("/metrics", get(metrics))
|
||||||
.layer(Extension(info))
|
.layer(Extension(info))
|
||||||
|
.layer(Extension(health_ext))
|
||||||
.layer(Extension(compat_return_full_text))
|
.layer(Extension(compat_return_full_text))
|
||||||
.layer(Extension(infer))
|
.layer(Extension(infer))
|
||||||
.layer(Extension(prom_handle))
|
.layer(Extension(prom_handle))
|
||||||
|
@ -29,6 +29,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||||||
async def Info(self, request, context):
|
async def Info(self, request, context):
|
||||||
return self.model.info
|
return self.model.info
|
||||||
|
|
||||||
|
async def Health(self, request, context):
|
||||||
|
return generate_pb2.HealthResponse()
|
||||||
|
|
||||||
async def ServiceDiscovery(self, request, context):
|
async def ServiceDiscovery(self, request, context):
|
||||||
return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
|
return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user