From cb69b09a778f52fca5a170e5cb80148ca79ddcaf Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 22 May 2024 21:13:30 +0000 Subject: [PATCH] feat: add kserve feature and basic routes --- router/Cargo.toml | 1 + router/src/lib.rs | 24 +++++++ router/src/server.rs | 165 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 187 insertions(+), 3 deletions(-) diff --git a/router/Cargo.toml b/router/Cargo.toml index 2e6264be..e9e1b292 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -58,3 +58,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } default = ["ngrok"] ngrok = ["dep:ngrok"] google = [] +kserve = [] diff --git a/router/src/lib.rs b/router/src/lib.rs index b6902c49..a6bbe00b 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -9,6 +9,30 @@ use tracing::warn; use utoipa::ToSchema; use validation::Validation; +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct LiveReponse { + pub live: bool, +} + +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct ReadyResponse { + pub live: bool, +} + +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct MetadataServerResponse { + pub name: String, + pub version: String, + pub extensions: Vec<String>, +} + +/// Type alias for generation responses +pub(crate) type GenerateStreamResponse = ( + OwnedSemaphorePermit, + u32, // input_length + UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, +); + #[derive(Clone, Deserialize, ToSchema)] pub(crate) struct VertexInstance { #[schema(example = "What is Deep Learning?")] diff --git a/router/src/server.rs b/router/src/server.rs index 30479b0e..2adb865b 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -18,8 +18,12 @@ use crate::{ CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse, }; use crate::{FunctionDefinition, ToolCall, ToolType}; +#[cfg(feature = "kserve")] +use crate::{LiveReponse, MetadataServerResponse, ReadyResponse}; use async_stream::__private::AsyncStream; use axum::extract::Extension; 
+#[cfg(feature = "kserve")] +use axum::extract::Path; use axum::http::{HeaderMap, Method, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::{IntoResponse, Response}; @@ -1369,6 +1373,112 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String { prom_handle.render() } +#[cfg(feature = "kserve")] +use serde_json::json; + +#[cfg(feature = "kserve")] +#[utoipa::path( + post, + tag = "Text Generation Inference", + path = "/v2/health/live", + responses( + (status = 200, description = "Live response", body = JsonValue), + (status = 404, description = "No response", body = ErrorResponse, + example = json ! ({"error": "No response"})), + ) + )] +// https://github.com/kserve/open-inference-protocol/blob/main/specification/protocol/inference_rest.md +async fn get_v2_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> { + let data = LiveReponse { live: true }; + Ok((HeaderMap::new(), Json(data)).into_response()) +} + +#[cfg(feature = "kserve")] +#[utoipa::path( + post, + tag = "Text Generation Inference", + path = "/v2/health/ready", + responses( + (status = 200, description = "Ready response", body = JsonValue), + (status = 404, description = "No response", body = ErrorResponse, + example = json ! ({"error": "No response"})), + ) + )] +async fn get_v2_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> { + let data = ReadyResponse { live: true }; + Ok((HeaderMap::new(), Json(data)).into_response()) +} + +#[cfg(feature = "kserve")] +#[utoipa::path( + get, + tag = "Text Generation Inference", + path = "/v2", + responses( + (status = 200, description = "Metadata response", body = JsonValue), + (status = 404, description = "No response", body = ErrorResponse, + example = json ! 
({"error": "No response"})), + ) + )] +async fn get_v2() -> Result<Response, (StatusCode, Json<ErrorResponse>)> { + let data = MetadataServerResponse { + name: "text-generation-inference".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + extensions: vec![ + "health".to_string(), + "models".to_string(), + "metrics".to_string(), + ], + }; + Ok((HeaderMap::new(), Json(data)).into_response()) +} + +#[cfg(feature = "kserve")] +#[utoipa::path( + get, + tag = "Text Generation Inference", + path = "/v2/models/{model_name}/versions/{model_version}", + responses( + (status = 200, description = "Model ready response", body = JsonValue), + (status = 404, description = "No response", body = ErrorResponse, + example = json ! ({"error": "No response"})), + ) + )] +async fn get_v2_models_model_name_versions_model_version( + Path((model_name, model_version)): Path<(String, String)>, +) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { + let data = MetadataServerResponse { + name: "gpt2".to_string(), + version: "1.0".to_string(), + extensions: vec!["infer".to_string(), "ready".to_string()], + }; + Ok((HeaderMap::new(), Json(data)).into_response()) +} + +// #[cfg(feature = "kserve")] // async fn get_v2_models_model_name_versions_model_version_ready() -> JsonValue { // let name = "gpt2"; // let ready = true; // json!({ // "name" : name, // "ready": ready, // }) // } // // TODO: Implement this route and resolve the req/res types // #[cfg(feature = "kserve")] // async fn post_v2_models_model_name_versions_model_version_infer() -> StatusCode { // // $inference_request = // // { // // "id" : $string #optional, // // "parameters" : $parameters #optional, // // "inputs" : [ $request_input, ... ], // // "outputs" : [ $request_output, ... 
] #optional +// // } + +// StatusCode::OK +// } + #[derive(Clone, Debug)] pub(crate) struct ComputeType(String); @@ -1711,7 +1821,6 @@ pub async fn run( // Define VertextApiDoc conditionally only if the "google" feature is enabled let doc = { - // avoid `mut` if possible #[cfg(feature = "google")] { use crate::VertexInstance; @@ -1723,13 +1832,40 @@ pub async fn run( )] struct VertextApiDoc; - // limiting mutability to the smallest scope necessary let mut doc = ApiDoc::openapi(); doc.merge(VertextApiDoc::openapi()); doc } #[cfg(not(feature = "google"))] - ApiDoc::openapi() + { + // Define KServeApiDoc conditionally only if the "kserve" feature is enabled + #[cfg(feature = "kserve")] + { + #[derive(OpenApi)] + #[openapi( + paths( + get_v2_health_live, + get_v2_health_ready, + get_v2, + get_v2_models_model_name_versions_model_version, + // get_v2_models_model_name_versions_model_version_ready, + // post_v2_models_model_name_versions_model_version_infer + ), + components(schemas( + LiveReponse, + ReadyResponse, + MetadataServerResponse, + )) + )] + struct KServeApiDoc; + + let mut doc = ApiDoc::openapi(); + doc.merge(KServeApiDoc::openapi()); + doc + } + #[cfg(not(feature = "kserve"))] + ApiDoc::openapi() + } }; // Configure Swagger UI @@ -1780,6 +1916,29 @@ pub async fn run( } } + #[cfg(feature = "kserve")] + { + tracing::info!("Built with `kserve` feature"); + app = app + .route("/v2/health/live", get(get_v2_health_live)) + .route("/v2/health/ready", get(get_v2_health_ready)) + .route("/v2", get(get_v2)) + .route( + // get metadata for a model version + "/v2/models/:model_name/versions/:model_version", + get(get_v2_models_model_name_versions_model_version), + ); + // .route( + // // get readiness for a model version + // "/v2/models/:model_name/versions/:model_version/ready", + // get(get_v2_models_model_name_versions_model_version_ready), + // ) + // .route( + // "/v2/models/:model_name/versions/:model_version/infer", + // 
post(post_v2_models_model_name_versions_model_version_infer), + // ); + } + // add layers after routes app = app .layer(Extension(info))