diff --git a/router/src/lib.rs b/router/src/lib.rs
index 39fa3b30..6ed25868 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -10,64 +10,61 @@ use utoipa::ToSchema;
 use validation::Validation;
 
 #[cfg(feature = "kserve")]
-#[derive(Debug, Serialize, Deserialize, ToSchema)]
-pub struct OutputChunk {
-    name: String,
-    shape: Vec<usize>,
-    datatype: String,
-    data: Vec<u8>,
-}
+mod kserve {
+    use super::*;
 
-#[cfg(feature = "kserve")]
-#[derive(Debug, Serialize, Deserialize, ToSchema)]
-pub struct InferenceOutput {
-    id: String,
-    outputs: Vec<OutputChunk>,
-}
+    #[derive(Debug, Serialize, Deserialize, ToSchema)]
+    pub struct OutputChunk {
+        pub name: String,
+        pub shape: Vec<usize>,
+        pub datatype: String,
+        pub data: Vec<u8>,
+    }
 
-#[cfg(feature = "kserve")]
-#[derive(Debug, Deserialize, ToSchema)]
-pub(crate) struct InferenceRequest {
-    pub id: String,
-    #[serde(default = "default_parameters")]
-    pub parameters: GenerateParameters,
-    pub inputs: Vec<Input>,
-    pub outputs: Vec<Output>,
-}
+    #[derive(Debug, Serialize, Deserialize, ToSchema)]
+    pub struct InferenceOutput {
+        pub id: String,
+        pub outputs: Vec<OutputChunk>,
+    }
 
-#[cfg(feature = "kserve")]
-#[derive(Debug, Serialize, Deserialize, ToSchema)]
-pub(crate) struct Input {
-    pub name: String,
-    pub shape: Vec<usize>,
-    pub datatype: String,
-    pub data: Vec<u8>,
-}
+    #[derive(Debug, Deserialize, ToSchema)]
+    pub(crate) struct InferenceRequest {
+        pub id: String,
+        #[serde(default = "default_parameters")]
+        pub parameters: GenerateParameters,
+        pub inputs: Vec<Input>,
+        pub outputs: Vec<Output>,
+    }
 
-#[cfg(feature = "kserve")]
-#[derive(Debug, Serialize, Deserialize, ToSchema)]
-pub(crate) struct Output {
-    pub name: String,
-}
+    #[derive(Debug, Serialize, Deserialize, ToSchema)]
+    pub(crate) struct Input {
+        pub name: String,
+        pub shape: Vec<usize>,
+        pub datatype: String,
+        pub data: Vec<u8>,
+    }
 
-#[cfg(feature = "kserve")]
-#[derive(Debug, Serialize, Deserialize, ToSchema)]
-pub struct LiveReponse {
-    pub live: bool,
-}
+    #[derive(Debug, Serialize, Deserialize, ToSchema)]
+    pub(crate) struct Output {
+        pub name: String,
+    }
 
-#[cfg(feature = "kserve")]
-#[derive(Debug, Serialize, Deserialize, ToSchema)]
-pub struct ReadyResponse {
-    pub live: bool,
-}
+    #[derive(Debug, Serialize, Deserialize, ToSchema)]
+    pub struct LiveReponse {
+        pub live: bool,
+    }
 
-#[cfg(feature = "kserve")]
-#[derive(Debug, Serialize, Deserialize, ToSchema)]
-pub struct MetadataServerResponse {
-    pub name: String,
-    pub version: String,
-    pub extensions: Vec<String>,
+    #[derive(Debug, Serialize, Deserialize, ToSchema)]
+    pub struct ReadyResponse {
+        pub live: bool,
+    }
+
+    #[derive(Debug, Serialize, Deserialize, ToSchema)]
+    pub struct MetadataServerResponse {
+        pub name: String,
+        pub version: String,
+        pub extensions: Vec<String>,
+    }
 }
 
 /// Type alias for generation responses
diff --git a/router/src/server.rs b/router/src/server.rs
index b157e76b..71136acf 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -4,6 +4,12 @@ use crate::infer::v2::SchedulerV2;
 use crate::infer::v3::SchedulerV3;
 use crate::infer::{HealthCheck, Scheduler};
 use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
+use crate::health::Health;
+#[cfg(feature = "kserve")]
+use crate::kserve::{
+    InferenceOutput, InferenceRequest, LiveReponse, MetadataServerResponse, OutputChunk,
+    ReadyResponse,
+};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
@@ -18,11 +24,6 @@ use crate::{
     CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
 };
 use crate::{FunctionDefinition, ToolCall, ToolType};
-#[cfg(feature = "kserve")]
-use crate::{
-    InferenceOutput, InferenceRequest, LiveReponse, MetadataServerResponse, OutputChunk,
-    ReadyResponse,
-};
 use async_stream::__private::AsyncStream;
 use axum::extract::Extension;
 #[cfg(feature = "kserve")]
@@ -1382,13 +1383,12 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     tag = "Text Generation Inference",
     path = "/v2/health/live",
     responses(
-    (status = 200, description = "Live response", body = LiveReponse),
-    (status = 404, description = "No response", body = ErrorResponse,
-    example = json ! ({"error": "No response"})),
+        (status = 200, description = "Service is live", body = LiveReponse),
+        (status = 404, description = "Service not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
     )
-    )]
-// https://github.com/kserve/open-inference-protocol/blob/main/specification/protocol/inference_rest.md
-async fn get_v2_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+)]
+async fn kserve_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let data = LiveReponse { live: true };
     Ok((HeaderMap::new(), Json(data)).into_response())
 }
@@ -1399,12 +1399,12 @@ async fn get_v2_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     tag = "Text Generation Inference",
     path = "/v2/health/ready",
     responses(
-    (status = 200, description = "Ready response", body = ReadyResponse),
-    (status = 404, description = "No response", body = ErrorResponse,
-    example = json ! ({"error": "No response"})),
+        (status = 200, description = "Service is ready", body = ReadyResponse),
+        (status = 404, description = "Service not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
     )
-    )]
-async fn get_v2_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+)]
+async fn kserve_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let data = ReadyResponse { live: true };
     Ok((HeaderMap::new(), Json(data)).into_response())
 }
@@ -1415,12 +1415,12 @@ async fn get_v2_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     tag = "Text Generation Inference",
     path = "/v2",
     responses(
-    (status = 200, description = "Metadata response", body = MetadataServerResponse),
-    (status = 404, description = "No response", body = ErrorResponse,
-    example = json ! ({"error": "No response"})),
+        (status = 200, description = "Server metadata retrieved", body = MetadataServerResponse),
+        (status = 404, description = "Service not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
     )
-    )]
-async fn get_v2() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+)]
+async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let data = MetadataServerResponse {
         name: "text-generation-inference".to_string(),
         version: env!("CARGO_PKG_VERSION").to_string(),
@@ -1439,12 +1439,12 @@ async fn get_v2() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     tag = "Text Generation Inference",
     path = "/v2/models/{model_name}/versions/{model_version}",
     responses(
-    (status = 200, description = "Model ready response", body = MetadataServerResponse),
-    (status = 404, description = "No response", body = ErrorResponse,
-    example = json ! ({"error": "No response"})),
+        (status = 200, description = "Model version metadata retrieved", body = MetadataServerResponse),
+        (status = 404, description = "Model or version not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
     )
-    )]
-async fn get_v2_models_model_name_versions_model_version(
+)]
+async fn kserve_model_metadata(
     Path((model_name, model_version)): Path<(String, String)>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let data = MetadataServerResponse {
@@ -1462,12 +1462,12 @@ async fn get_v2_models_model_name_versions_model_version(
     path = "/v2/models/{model_name}/versions/{model_version}/infer",
     request_body = Json<InferenceRequest>,
     responses(
-    (status = 200, description = "Inference response", body = InferenceOutput),
-    (status = 404, description = "No response", body = ErrorResponse,
-    example = json ! ({"error": "No response"})),
+        (status = 200, description = "Inference executed successfully", body = InferenceOutput),
+        (status = 404, description = "Model or version not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
     )
-    )]
-async fn post_v2_models_model_name_versions_model_version_infer(
+)]
+async fn kserve_model_infer(
     infer: Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Json(payload): Json<InferenceRequest>,
@@ -1551,12 +1551,12 @@ async fn post_v2_models_model_name_versions_model_version_infer(
     tag = "Text Generation Inference",
     path = "/v2/models/{model_name}/versions/{model_version}/ready",
     responses(
-    (status = 200, description = "Model ready response", body = ReadyResponse),
-    (status = 404, description = "No response", body = ErrorResponse,
-    example = json ! ({"error": "No response"})),
+        (status = 200, description = "Model version is ready", body = ReadyResponse),
+        (status = 404, description = "Model or version not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
     )
-    )]
-async fn get_v2_models_model_name_versions_model_version_ready(
+)]
+async fn kserve_model_metadata_ready(
     Path((_model_name, _model_version)): Path<(String, String)>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let data = ReadyResponse { live: true };
@@ -1928,12 +1928,12 @@ pub async fn run(
     #[derive(OpenApi)]
     #[openapi(
         paths(
-        post_v2_models_model_name_versions_model_version_infer,
-        get_v2_health_live,
-        get_v2_health_ready,
-        get_v2,
-        get_v2_models_model_name_versions_model_version,
-        get_v2_models_model_name_versions_model_version_ready,
+            kserve_model_infer,
+            kserve_health_live,
+            kserve_health_ready,
+            kerve_server_metadata,
+            kserve_model_metadata,
+            kserve_model_metadata_ready,
         ),
         components(schemas(LiveReponse, ReadyResponse, MetadataServerResponse,))
     )]
@@ -2002,18 +2002,18 @@ pub async fn run(
         app = app
             .route(
                 "/v2/models/:model_name/versions/:model_version/infer",
-                post(post_v2_models_model_name_versions_model_version_infer),
+                post(kserve_model_infer),
             )
            .route(
                 "/v2/models/:model_name/versions/:model_version",
-                get(get_v2_models_model_name_versions_model_version),
+                get(kserve_model_metadata),
             )
-            .route("/v2/health/ready", get(get_v2_health_ready))
-            .route("/v2/health/live", get(get_v2_health_live))
-            .route("/v2", get(get_v2))
+            .route("/v2/health/ready", get(kserve_health_ready))
+            .route("/v2/health/live", get(kserve_health_live))
+            .route("/v2", get(kerve_server_metadata))
             .route(
                 "/v2/models/:model_name/versions/:model_version/ready",
-                get(get_v2_models_model_name_versions_model_version_ready),
+                get(kserve_model_metadata_ready),
             );
     }
 