From 501b7c44361eb3ca5d5452af065711f7d7777198 Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 23 May 2024 01:00:55 +0000
Subject: [PATCH] feat: implement infer endpoint wrapper around generate

---
 router/src/lib.rs    |  34 ++++++++
 router/src/server.rs | 202 +++++++++++++++++++++++++++++++++----------
 2 files changed, 192 insertions(+), 44 deletions(-)

diff --git a/router/src/lib.rs b/router/src/lib.rs
index a6bbe00b..aa6959ee 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -9,6 +9,40 @@ use tracing::warn;
 use utoipa::ToSchema;
 use validation::Validation;
 
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct OutputChunk {
+    name: String,
+    shape: Vec<usize>,
+    datatype: String,
+    data: Vec<u8>,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct InferenceOutput {
+    id: String,
+    outputs: Vec<OutputChunk>,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct InferenceRequest {
+    pub id: String,
+    pub inputs: Vec<Input>,
+    pub outputs: Vec<Output>,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct Input {
+    pub name: String,
+    pub shape: Vec<usize>,
+    pub datatype: String,
+    pub data: Vec<u8>,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct Output {
+    pub name: String,
+}
+
 #[derive(Debug, Serialize, Deserialize, ToSchema)]
 pub struct LiveReponse {
     pub live: bool,
diff --git a/router/src/server.rs b/router/src/server.rs
index 2adb865b..d2a29a05 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -19,7 +19,10 @@ use crate::{
 };
 use crate::{FunctionDefinition, ToolCall, ToolType};
 #[cfg(feature = "kserve")]
-use crate::{LiveReponse, MetadataServerResponse, ReadyResponse};
+use crate::{
+    InferenceOutput, InferenceRequest, LiveReponse, MetadataServerResponse, OutputChunk,
+    ReadyResponse,
+};
 use async_stream::__private::AsyncStream;
 use axum::extract::Extension;
 #[cfg(feature = "kserve")]
@@ -1448,36 +1451,144 @@ async fn get_v2_models_model_name_versions_model_version(
     Path((model_name, model_version)): Path<(String, String)>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let data = MetadataServerResponse {
-        name: "gpt2".to_string(),
-        version: "1.0".to_string(),
+        name: model_name,
+        version: model_version,
         extensions: vec!["infer".to_string(), "ready".to_string()],
     };
     Ok((HeaderMap::new(), Json(data)).into_response())
 }
 
-// #[cfg(feature = "kserve")]
-// async fn get_v2_models_model_name_versions_model_version_ready() -> JsonValue {
-//     let name = "gpt2";
-//     let ready = true;
-//     json!({
-//         "name" : name,
-//         "ready": ready,
-//     })
-// }
+#[cfg(feature = "kserve")]
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/v2/models/{model_name}/versions/{model_version}/infer",
+    request_body = Json<InferenceRequest>,
+    responses(
+        (status = 200, description = "Inference response", body = JsonValue),
+        (status = 404, description = "No response", body = ErrorResponse,
+        example = json !
+        ({"error": "No response"})),
+    )
+    )]
+async fn post_v2_models_model_name_versions_model_version_infer(
+    infer: Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
+    Json(payload): Json<InferenceRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    println!("Received payload: {:?}", payload);
 
-// // TODO: Implement this route and resolve the req/res types
-// #[cfg(feature = "kserve")]
-// async fn post_v2_models_model_name_versions_model_version_infer() -> StatusCode {
-//     // $inference_request =
-//     // {
-//     //   "id" : $string #optional,
-//     //   "parameters" : $parameters #optional,
-//     //   "inputs" : [ $request_input, ... ],
-//     //   "outputs" : [ $request_output, ... ] #optional
-//     // }
+    // let mut output_chunks = Vec::new();
+    let id = payload.id.clone();
 
-//     StatusCode::OK
-// }
+    let str_inputs = payload
+        .inputs
+        .iter()
+        .map(|input| {
+            std::str::from_utf8(&input.data).map_err(|e| {
+                (
+                    StatusCode::UNPROCESSABLE_ENTITY,
+                    Json(ErrorResponse {
+                        error: e.to_string(),
+                        error_type: "utf8".to_string(),
+                    }),
+                )
+            })
+        })
+        .collect::<Result<Vec<&str>, _>>()?;
+
+    let output_chunks = payload
+        .inputs
+        .iter()
+        .zip(payload.outputs.iter())
+        .zip(str_inputs.iter())
+        .map(|((input, output), str_input)| {
+            let generate_request = GenerateRequest {
+                inputs: str_input.to_string(),
+                parameters: GenerateParameters {
+                    best_of: None,
+                    temperature: None,
+                    repetition_penalty: None,
+                    frequency_penalty: None,
+                    top_k: None,
+                    top_p: None,
+                    typical_p: None,
+                    do_sample: true,
+                    max_new_tokens: Some(100),
+                    return_full_text: None,
+                    stop: Vec::new(),
+                    truncate: None,
+                    watermark: false,
+                    details: true,
+                    decoder_input_details: false,
+                    seed: None,
+                    top_n_tokens: None,
+                    grammar: None,
+                },
+            };
+
+            async {
+                let span = tracing::Span::current();
+                generate_internal(
+                    infer.clone(),
+                    compute_type.clone(),
+                    Json(generate_request),
+                    span.clone(),
+                )
+                .await
+                .map(|(_, Json(generation))| {
+                    let generation_as_bytes = generation.generated_text.as_bytes().to_vec();
+                    let output_name = output.name.clone();
+                    let output_shape = input.shape.clone();
+                    OutputChunk {
+                        name: output_name,
+                        shape: output_shape,
+                        datatype: "BYTES".to_string(),
+                        data: generation_as_bytes,
+                    }
+                })
+                .map_err(|_| {
+                    (
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        Json(ErrorResponse {
+                            error: "Incomplete generation".into(),
+                            error_type: "Incomplete generation".into(),
+                        }),
+                    )
+                })
+            }
+        })
+        .collect::<FuturesUnordered<_>>()
+        .try_collect::<Vec<_>>()
+        .await?;
+
+    let inference_output = InferenceOutput {
+        id: id.clone(),
+        outputs: output_chunks,
+    };
+
+    Ok((HeaderMap::new(), Json(inference_output)).into_response())
+}
+
+#[cfg(feature = "kserve")]
+#[utoipa::path(
+    get,
+    tag = "Text Generation Inference",
+    path = "/v2/models/{model_name}/versions/{model_version}/ready",
+    responses(
+        (status = 200, description = "Model ready response", body = JsonValue),
+        (status = 404, description = "No response", body = ErrorResponse,
+        example = json !
+        ({"error": "No response"})),
+    )
+    )]
+async fn get_v2_models_model_name_versions_model_version_ready(
+    Path((model_name, _model_version)): Path<(String, String)>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let data = json!({
+        "name": model_name,
+        "ready": true,
+    });
+    Ok((HeaderMap::new(), Json(data)).into_response())
+}
 
 #[derive(Clone, Debug)]
 pub(crate) struct ComputeType(String);
@@ -1844,18 +1955,14 @@ pub async fn run(
     #[derive(OpenApi)]
     #[openapi(
         paths(
+            post_v2_models_model_name_versions_model_version_infer,
            get_v2_health_live,
            get_v2_health_ready,
            get_v2,
            get_v2_models_model_name_versions_model_version,
-            // get_v2_models_model_name_versions_model_version_ready,
-            // post_v2_models_model_name_versions_model_version_infer
+            get_v2_models_model_name_versions_model_version_ready,
         ),
-        components(schemas(
-            LiveReponse,
-            ReadyResponse,
-            MetadataServerResponse,
-        ))
+        components(schemas(InferenceOutput, InferenceRequest, LiveReponse, MetadataServerResponse, OutputChunk, ReadyResponse,))
     )]
     struct KServeApiDoc;
 
@@ -1919,24 +2026,31 @@ pub async fn run(
     #[cfg(feature = "kserve")]
     {
         tracing::info!("Built with `kserve` feature");
+
+        // [ ] Inference |          | POST v2/models/<model_name>[/versions/<model_version>]/infer  $inference_request  $inference_response
+        // [X] Model     | Metadata | GET  v2/models/<model_name>[/versions/<model_version>]        $metadata_model_response
+        // [X] Server    | Ready    | GET  v2/health/ready                                          $ready_server_response
+        // [X] Server    | Live     | GET  v2/health/live                                           $live_server_response
+        // [X] Server    | Metadata | GET  v2                                                       $metadata_server_response
+        // [ ] Model     | Ready    | GET  v2/models/<model_name>[/versions/<model_version>]/ready
+
         app = app
-            .route("/v2/health/live", get(get_v2_health_live))
-            .route("/v2/health/ready", get(get_v2_health_ready))
-            .route("/v2", get(get_v2))
             .route(
-                // get metadata for a model version
+                "/v2/models/:model_name/versions/:model_version/infer",
+                post(post_v2_models_model_name_versions_model_version_infer),
+            )
+            .route(
                 "/v2/models/:model_name/versions/:model_version",
                 get(get_v2_models_model_name_versions_model_version),
+            )
+            .route("/v2/health/ready", get(get_v2_health_ready))
+            .route("/v2/health/live", get(get_v2_health_live))
+            .route("/v2", get(get_v2))
+            .route(
+                // get readiness for a model version
+                "/v2/models/:model_name/versions/:model_version/ready",
+                get(get_v2_models_model_name_versions_model_version_ready),
             );
-            // .route(
-            //     // get readiness for a model version
-            //     "/v2/models/:model_name/versions/:model_version/ready",
-            //     get(get_v2_models_model_name_versions_model_version_ready),
-            // )
-            // .route(
-            //     "/v2/models/:model_name/versions/:model_version/infer",
-            //     post(post_v2_models_model_name_versions_model_version_infer),
-            // );
     }
 
     // add layers after routes
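
---

Usage sketch (illustrative, not part of the diff): with the routes above mounted, a round-trip against the new infer endpoint might look like the following. Everything here is an assumption rather than taken from the patch: the port is the router's default (3000), the model name and version in the URL are arbitrary path values, and "data" carries the prompt as an array of byte values, which is how serde's default JSON encoding of Vec<u8> is expressed. The bytes below spell "Hello".

curl -X POST http://localhost:3000/v2/models/my-model/versions/1/infer \
    -H 'Content-Type: application/json' \
    -d '{
          "id": "req-0",
          "inputs": [
            {"name": "input-0", "shape": [1], "datatype": "BYTES",
             "data": [72, 101, 108, 108, 111]}
          ],
          "outputs": [{"name": "output-0"}]
        }'

The handler zips inputs, outputs, and the decoded strings pairwise, so every input needs a matching entry in "outputs" (unpaired inputs are silently dropped by the zip). A successful response mirrors the request id and carries one OutputChunk per input, with "shape" copied from the corresponding input and "data" holding the UTF-8 bytes of the generated text, roughly:

{
  "id": "req-0",
  "outputs": [
    {"name": "output-0", "shape": [1], "datatype": "BYTES",
     "data": [84, 104, 101, 32, 46, 46, 46]}
  ]
}

(The response bytes here are placeholders; the actual contents depend on the model.)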