From 0f1c4b12ca05e4cca93db6b6562fd6f44d0b325f Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 23 May 2024 16:11:39 +0000 Subject: [PATCH] fix: refactor and improve types --- router/src/lib.rs | 18 +++++-- router/src/server.rs | 114 +++++++++++++------------------------------ 2 files changed, 49 insertions(+), 83 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index aa6959ee..39fa3b30 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -9,6 +9,7 @@ use tracing::warn; use utoipa::ToSchema; use validation::Validation; +#[cfg(feature = "kserve")] #[derive(Debug, Serialize, Deserialize, ToSchema)] pub struct OutputChunk { name: String, @@ -17,42 +18,51 @@ pub struct OutputChunk { data: Vec, } +#[cfg(feature = "kserve")] #[derive(Debug, Serialize, Deserialize, ToSchema)] pub struct InferenceOutput { id: String, outputs: Vec, } -#[derive(Debug, Serialize, Deserialize, ToSchema)] -pub struct InferenceRequest { +#[cfg(feature = "kserve")] +#[derive(Debug, Deserialize, ToSchema)] +pub(crate) struct InferenceRequest { pub id: String, + #[serde(default = "default_parameters")] + pub parameters: GenerateParameters, pub inputs: Vec, pub outputs: Vec, } +#[cfg(feature = "kserve")] #[derive(Debug, Serialize, Deserialize, ToSchema)] -pub struct Input { +pub(crate) struct Input { pub name: String, pub shape: Vec, pub datatype: String, pub data: Vec, } +#[cfg(feature = "kserve")] #[derive(Debug, Serialize, Deserialize, ToSchema)] -pub struct Output { +pub(crate) struct Output { pub name: String, } +#[cfg(feature = "kserve")] #[derive(Debug, Serialize, Deserialize, ToSchema)] pub struct LiveReponse { pub live: bool, } +#[cfg(feature = "kserve")] #[derive(Debug, Serialize, Deserialize, ToSchema)] pub struct ReadyResponse { pub live: bool, } +#[cfg(feature = "kserve")] #[derive(Debug, Serialize, Deserialize, ToSchema)] pub struct MetadataServerResponse { pub name: String, diff --git a/router/src/server.rs b/router/src/server.rs index d2a29a05..f4e2447f 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1376,16 +1376,13 @@ async fn metrics(prom_handle: Extension) -> String { prom_handle.render() } -#[cfg(feature = "kserve")] -use serde_json::json; - #[cfg(feature = "kserve")] #[utoipa::path( post, tag = "Text Generation Inference", path = "/v2/health/live", responses( - (status = 200, description = "Live response", body = JsonValue), + (status = 200, description = "Live response", body = LiveReponse), (status = 404, description = "No response", body = ErrorResponse, example = json ! ({"error": "No response"})), ) @@ -1402,7 +1399,7 @@ async fn get_v2_health_live() -> Result Result Result)> { tag = "Text Generation Inference", path = "/v2/models/{model_name}/versions/{model_version}", responses( - (status = 200, description = "Model ready response", body = JsonValue), + (status = 200, description = "Model ready response", body = MetadataServerResponse), (status = 404, description = "No response", body = ErrorResponse, example = json ! ({"error": "No response"})), ) @@ -1465,7 +1462,7 @@ async fn get_v2_models_model_name_versions_model_version( path = "/v2/models/{model_name}/versions/{model_version}/infer", request_body = Json, responses( - (status = 200, description = "Inference response", body = JsonValue), + (status = 200, description = "Inference response", body = InferenceOutput), (status = 404, description = "No response", body = ErrorResponse, example = json ! ({"error": "No response"})), ) @@ -1475,11 +1472,7 @@ async fn post_v2_models_model_name_versions_model_version_infer( Extension(compute_type): Extension, Json(payload): Json, ) -> Result)> { - println!("Received payload: {:?}", payload); - - // let mut output_chunks = Vec::new(); let id = payload.id.clone(); - let str_inputs = payload .inputs .iter() @@ -1499,62 +1492,37 @@ async fn post_v2_models_model_name_versions_model_version_infer( let output_chunks = payload .inputs .iter() - .zip(payload.outputs.iter()) - .zip(str_inputs.iter()) + .zip(&payload.outputs) + .zip(&str_inputs) .map(|((input, output), str_input)| { let generate_request = GenerateRequest { inputs: str_input.to_string(), - parameters: GenerateParameters { - best_of: None, - temperature: None, - repetition_penalty: None, - frequency_penalty: None, - top_k: None, - top_p: None, - typical_p: None, - do_sample: true, - max_new_tokens: Some(100), - return_full_text: None, - stop: Vec::new(), - truncate: None, - watermark: false, - details: true, - decoder_input_details: false, - seed: None, - top_n_tokens: None, - grammar: None, - }, + parameters: payload.parameters.clone(), }; - - async { - let span = tracing::Span::current(); - generate_internal( - infer.clone(), - compute_type.clone(), - Json(generate_request), - span.clone(), - ) - .await - .map(|(_, Json(generation))| { - let generation_as_bytes = generation.generated_text.as_bytes().to_vec(); - let output_name = output.name.clone(); - let output_shape = input.shape.clone(); - OutputChunk { - name: output_name, - shape: output_shape, - datatype: "BYTES".to_string(), - data: generation_as_bytes, - } - }) - .map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: "Incomplete generation".into(), - error_type: "Incomplete generation".into(), - }), - ) - }) + let infer = infer.clone(); + let compute_type = compute_type.clone(); + let span = tracing::Span::current(); + async move { + generate_internal(infer, compute_type, Json(generate_request), span) + .await + .map(|(_, Json(generation))| { + let generation_as_bytes = generation.generated_text.as_bytes().to_vec(); + OutputChunk { + name: output.name.clone(), + shape: input.shape.clone(), + datatype: "BYTES".to_string(), + data: generation_as_bytes, + } + }) + .map_err(|_| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Incomplete generation".into(), + error_type: "Incomplete generation".into(), + }), + ) + }) } }) .collect::>() @@ -1575,18 +1543,15 @@ async fn post_v2_models_model_name_versions_model_version_infer( tag = "Text Generation Inference", path = "/v2/models/{model_name}/versions/{model_version}/ready", responses( - (status = 200, description = "Model ready response", body = JsonValue), + (status = 200, description = "Model ready response", body = ReadyResponse), (status = 404, description = "No response", body = ErrorResponse, example = json ! ({"error": "No response"})), ) )] async fn get_v2_models_model_name_versions_model_version_ready( - Path((model_name, _model_version)): Path<(String, String)>, + Path((_model_name, _model_version)): Path<(String, String)>, ) -> Result)> { - let data = json!({ - "name": model_name, - "ready": true, - }); + let data = ReadyResponse { live: true }; Ok((HeaderMap::new(), Json(data)).into_response()) } @@ -2026,14 +1991,6 @@ pub async fn run( #[cfg(feature = "kserve")] { tracing::info!("Built with `kserve` feature"); - - // [ ] Inference | | POST v2/models/[/versions/]/infer $inference_request $inference_response - // [X] Model | Metadata | GET v2/models/[/versions/] $metadata_model_response - // [X] Server | Ready | GET v2/health/ready $ready_server_response - // [X] Server | Live | GET v2/health/live $live_server_response - // [X] Server | Metadata | GET v2 $metadata_server_response - // [ ] Model | Ready | GET v2/models/[/versions/]/ready - app = app .route( "/v2/models/:model_name/versions/:model_version/infer", @@ -2047,7 +2004,6 @@ pub async fn run( .route("/v2/health/live", get(get_v2_health_live)) .route("/v2", get(get_v2)) .route( - // get readiness for a model version "/v2/models/:model_name/versions/:model_version/ready", get(get_v2_models_model_name_versions_model_version_ready), );