diff --git a/proto/generate.proto b/proto/generate.proto
index ad47409e..894d7bc1 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -15,8 +15,13 @@ service TextGenerationService {
     rpc Prefill (PrefillRequest) returns (PrefillResponse);
     /// Decode token for a list of prefilled batches
     rpc Decode (DecodeRequest) returns (DecodeResponse);
+    /// Health check
+    rpc Health (HealthRequest) returns (HealthResponse);
 }
 
+message HealthRequest {}
+message HealthResponse {}
+
 /// Empty request
 message InfoRequest {}
 
@@ -173,4 +178,4 @@ message DecodeResponse {
     repeated Generation generations = 1;
     /// Next batch (cached)
     optional Batch batch = 2;
-}
\ No newline at end of file
+}
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 7cadf430..bf1b6b58 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -7,7 +7,7 @@ use tonic::transport::{Channel, Uri};
 use tracing::instrument;
 
 /// Text Generation Inference gRPC client
-#[derive(Clone)]
+#[derive(Debug, Clone)]
 pub struct Client {
     stub: TextGenerationServiceClient<Channel>,
 }
@@ -62,6 +62,14 @@ impl Client {
         Ok(response)
     }
 
+    /// Get model health
+    #[instrument(skip(self))]
+    pub async fn health(&mut self) -> Result<HealthResponse> {
+        let request = tonic::Request::new(HealthRequest {}).inject_context();
+        let response = self.stub.health(request).await?.into_inner();
+        Ok(response)
+    }
+
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index 6a001306..401082c5 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -6,6 +6,7 @@ mod pb;
 mod sharded_client;
 
 pub use client::Client;
+pub use pb::generate::v1::HealthResponse;
 pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
     Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index 469d75f6..2f57a437 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -1,10 +1,11 @@
 /// Multi shard Client
 use crate::Result;
-use crate::{Batch, Client, Generation, Request, ShardInfo};
+use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;
+#[derive(Debug, Clone)]
 /// Text Generation Inference gRPC multi client
 pub struct ShardedClient {
     clients: Vec<Client>,
 }
@@ -48,6 +49,17 @@ impl ShardedClient {
         join_all(futures).await.pop().unwrap()
     }
 
+    /// GRPC health check
+    #[instrument(skip(self))]
+    pub async fn health(&mut self) -> Result<HealthResponse> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| client.health())
+            .collect();
+        join_all(futures).await.pop().unwrap()
+    }
+
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 7a1707d9..4f73fa16 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -7,6 +7,7 @@ mod validation;
 use infer::Infer;
 use queue::{Entry, Queue};
 use serde::{Deserialize, Serialize};
+use text_generation_client::ShardedClient;
 use utoipa::ToSchema;
 use validation::Validation;
 
@@ -19,6 +20,20 @@ pub struct HubModelInfo {
     pub pipeline_tag: Option<String>,
 }
 
+#[derive(Clone, Debug, Serialize, ToSchema)]
+pub struct HealthResponse {}
+
+#[derive(Clone, Debug)]
+pub struct Health {
+    pub client: ShardedClient,
+}
+
+impl Health {
+    pub fn new(client: ShardedClient) -> Self {
+        Self { client }
+    }
+}
+
 #[derive(Clone, Debug, Serialize, ToSchema)]
 pub struct Info {
     /// Model info
diff --git a/router/src/server.rs b/router/src/server.rs
index 9540ba18..ce8c59dc 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -3,8 +3,8 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
-    GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
-    StreamDetails, StreamResponse, Token, Validation,
+    GenerateParameters, GenerateRequest, GenerateResponse, Health, HubModelInfo, Infer, Info,
+    PrefillToken, StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -82,9 +82,43 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
     Json(info.0)
 }
 
+#[utoipa::path(
+    get,
+    tag = "Text Generation Inference",
+    path = "/health",
+    request_body = HealthRequest,
+    responses(
+        (status = 200, description = "Everything is working fine"),
+        (status = 500, description = "Text generation inference is down", body = ErrorResponse,
+        example = json ! ({"error": "unhealthy"})),
+    )
+)]
+#[instrument]
 /// Health check method
+async fn health(
+    mut health: Extension<Health>,
+) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
+    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
+    // be a bit too slow for a health check.
+    // What we should do instead is check if the gRPC channels are still healthy.
+
+    // Send a small inference request
+    health.client.health().await.map_err(|_| {
+        (
+            StatusCode::INTERNAL_SERVER_ERROR,
+            Json(ErrorResponse {
+                error: "unhealthy".to_string(),
+                error_type: "healthcheck".to_string(),
+            }),
+        )
+    })?;
+    Ok(axum::Json(()))
+}
+
 #[instrument(skip(infer))]
-async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
+async fn health_generate(
+    infer: Extension<Infer>,
+) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
     // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
     // be a bit too slow for a health check.
     // What we should do instead is check if the gRPC channels are still healthy.
@@ -111,7 +145,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
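
The diff is truncated before the body of the final server.rs hunk, which presumably wires the new handler into the route table. As a rough illustration only (not the PR's actual code), a handler with this shape is mounted in axum 0.6 roughly as follows; the `Health` stub, its `check` method, and the bind address below are hypothetical stand-ins for the PR's `Health { client: ShardedClient }` and the fan-out `ShardedClient::health` call.

use axum::{extract::Extension, http::StatusCode, routing::get, Json, Router};

#[derive(Clone)]
struct Health; // stand-in for the PR's `Health { client: ShardedClient }`

impl Health {
    // Stand-in for the fan-out `ShardedClient::health` call in the diff.
    async fn check(&mut self) -> Result<(), ()> {
        Ok(())
    }
}

async fn health(
    Extension(mut health): Extension<Health>,
) -> Result<Json<()>, (StatusCode, Json<String>)> {
    // Any shard error becomes a 500, mirroring the "unhealthy" mapping above.
    health
        .check()
        .await
        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json("unhealthy".to_string())))?;
    Ok(Json(()))
}

#[tokio::main]
async fn main() {
    // The Extension layer is what makes `Extension<Health>` extractable in the
    // handler, matching the extractor style of the new /health route.
    let app = Router::new()
        .route("/health", get(health))
        .layer(Extension(Health));

    // axum 0.6-style startup; newer axum versions use `axum::serve` instead.
    axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}

With the route mounted, a 200 with an empty JSON body means every shard answered the gRPC health check; any shard error maps to a 500.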