New healthcheck that doesn't hit the queue.

This commit is contained in:
Nicolas Patry 2023-04-26 12:17:00 +02:00
parent 7de8a377b0
commit e1867079fd
7 changed files with 88 additions and 7 deletions

View File

@ -15,8 +15,13 @@ service TextGenerationService {
rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches
rpc Decode (DecodeRequest) returns (DecodeResponse);
/// Health check
rpc Health (HealthRequest) returns (HealthResponse);
}
/// Empty health request
message HealthRequest {}
/// Empty health response; a reply at all means the shard is alive
message HealthResponse {}
/// Empty request
message InfoRequest {}
@ -173,4 +178,4 @@ message DecodeResponse {
repeated Generation generations = 1;
/// Next batch (cached)
optional Batch batch = 2;
}
}

View File

@ -7,7 +7,7 @@ use tonic::transport::{Channel, Uri};
use tracing::instrument;
/// Text Generation Inference gRPC client
#[derive(Clone)]
#[derive(Debug, Clone)]
pub struct Client {
stub: TextGenerationServiceClient<Channel>,
}
@ -62,6 +62,14 @@ impl Client {
Ok(response)
}
/// Get model health
///
/// Issues the gRPC `Health` RPC against this shard; an `Ok` reply
/// means the shard is reachable and serving.
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
    // Empty request body; `inject_context` propagates the current
    // tracing span across the gRPC boundary.
    let request = tonic::Request::new(HealthRequest {}).inject_context();
    Ok(self.stub.health(request).await?.into_inner())
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {

View File

@ -6,6 +6,7 @@ mod pb;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v1::HealthResponse;
pub use pb::generate::v1::InfoResponse as ShardInfo;
pub use pb::generate::v1::{
Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,

View File

@ -1,10 +1,11 @@
/// Multi shard Client
use crate::Result;
use crate::{Batch, Client, Generation, Request, ShardInfo};
use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
#[derive(Debug, Clone)]
/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
clients: Vec<Client>,
@ -48,6 +49,17 @@ impl ShardedClient {
join_all(futures).await.pop().unwrap()
}
/// GRPC health check
///
/// Probes every shard concurrently and succeeds only if *all* shards
/// report healthy.
///
/// # Errors
/// Returns the first shard error encountered (in client order).
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
    let futures: Vec<_> = self
        .clients
        .iter_mut()
        .map(|client| client.health())
        .collect();
    // Collect into `Result` so any shard failure fails the whole check.
    // The previous `.pop().unwrap()` pattern only inspected the last
    // shard's result, silently discarding errors from the others.
    let mut responses: Vec<HealthResponse> = join_all(futures)
        .await
        .into_iter()
        .collect::<Result<Vec<_>>>()?;
    // All shards healthy: return one of the (empty) responses.
    Ok(responses.pop().expect("at least one shard client"))
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {

View File

@ -7,6 +7,7 @@ mod validation;
use infer::Infer;
use queue::{Entry, Queue};
use serde::{Deserialize, Serialize};
use text_generation_client::ShardedClient;
use utoipa::ToSchema;
use validation::Validation;
@ -19,6 +20,20 @@ pub struct HubModelInfo {
pub pipeline_tag: Option<String>,
}
/// Payload of a successful health probe; intentionally empty — the
/// HTTP status code carries all the information.
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct HealthResponse {}
/// State for the health route: a sharded gRPC client that can reach
/// every model shard without going through the inference queue.
#[derive(Clone, Debug)]
pub struct Health {
// Cloned per-request by axum's `Extension` layer (see derive above).
pub client: ShardedClient,
}
impl Health {
/// Wrap an existing `ShardedClient` for use by the health route.
pub fn new(client: ShardedClient) -> Self {
Self { client }
}
}
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct Info {
/// Model info

View File

@ -3,8 +3,8 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
StreamDetails, StreamResponse, Token, Validation,
GenerateParameters, GenerateRequest, GenerateResponse, Health, HubModelInfo, Infer, Info,
PrefillToken, StreamDetails, StreamResponse, Token, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
@ -82,9 +82,43 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
Json(info.0)
}
/// Health check method
///
/// Performs a lightweight gRPC `Health` call against every model shard
/// via `ShardedClient` instead of running a real generation, so the
/// probe does not enter the request queue.
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/health",
// NOTE: no `request_body` — this is a GET endpoint with no body, and
// `HealthRequest` is a gRPC proto type, not a registered HTTP schema.
responses(
(status = 200, description = "Everything is working fine"),
(status = 500, description = "Text generation inference is down", body = ErrorResponse,
example = json ! ({"error": "unhealthy"})),
)
)]
// Skip `health` in the span for consistency with `health_generate`'s
// `#[instrument(skip(infer))]` and to keep the gRPC client out of traces.
#[instrument(skip(health))]
async fn health(
    mut health: Extension<Health>,
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
    // Any shard error is collapsed into a single opaque 500: callers of
    // a health probe only need healthy/unhealthy, not the gRPC details.
    health.client.health().await.map_err(|_| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: "unhealthy".to_string(),
                error_type: "healthcheck".to_string(),
            }),
        )
    })?;
    Ok(axum::Json(()))
}
#[instrument(skip(infer))]
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
async fn health_generate(
infer: Extension<Infer>,
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
// be a bit too slow for a health check.
// What we should do instead is check if the gRPC channels are still healthy.
@ -111,7 +145,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
},
})
.await?;
Ok(())
Ok(axum::Json(()))
}
/// Generate tokens
@ -555,6 +589,7 @@ pub async fn run(
max_input_length,
max_total_tokens,
);
let health_ext = Health::new(client.clone());
let infer = Infer::new(
client,
validation,
@ -650,6 +685,7 @@ pub async fn run(
.route("/invocations", post(compat_generate))
// Base Health route
.route("/health", get(health))
.route("/health_generate", get(health_generate))
// Inference API health route
.route("/", get(health))
// AWS Sagemaker health route
@ -657,6 +693,7 @@ pub async fn run(
// Prometheus metrics route
.route("/metrics", get(metrics))
.layer(Extension(info))
.layer(Extension(health_ext))
.layer(Extension(compat_return_full_text))
.layer(Extension(infer))
.layer(Extension(prom_handle))

View File

@ -29,6 +29,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
async def Info(self, request, context):
    """Reply with this shard's model info object (``self.model.info``)."""
    model_info = self.model.info
    return model_info
async def Health(self, request, context):
    """Liveness probe: an empty ``HealthResponse`` means the shard is up."""
    response = generate_pb2.HealthResponse()
    return response
async def ServiceDiscovery(self, request, context):
    """Reply with the shard urls this server knows about."""
    urls = self.server_urls
    return generate_pb2.ServiceDiscoveryResponse(urls=urls)