use crate::infer::Infer;
use crate::{
    default_parameters,
    server::{generate_internal, ComputeType},
    Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Serialize, ToSchema,
};
use axum::extract::{Extension, Path};
use axum::http::{HeaderMap, StatusCode};
use axum::response::IntoResponse;
use axum::Json;
use futures::stream::FuturesUnordered;
use futures::TryStreamExt;

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct OutputChunk {
    pub name: String,
    pub shape: Vec<usize>,
    pub datatype: String,
    pub data: Vec<u8>,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct InferenceOutput {
    pub id: String,
    pub outputs: Vec<OutputChunk>,
}

#[derive(Debug, Deserialize, ToSchema)]
pub(crate) struct InferenceRequest {
    pub id: String,
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,
    pub inputs: Vec<Input>,
    pub outputs: Vec<Output>,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub(crate) struct Input {
    pub name: String,
    pub shape: Vec<usize>,
    pub datatype: String,
    pub data: Vec<u8>,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub(crate) struct Output {
    pub name: String,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct LiveResponse {
    pub live: bool,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ReadyResponse {
    pub live: bool,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct MetadataServerResponse {
    pub name: String,
    pub version: String,
    pub extensions: Vec<String>,
}

#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/v2/health/live",
    responses(
        (status = 200, description = "Service is live", body = LiveResponse),
        (status = 404, description = "Service not found", body = ErrorResponse,
            example = json!({"error": "No response"}))
    )
)]
pub async fn kserve_health_live() -> Json<LiveResponse> {
    let data = LiveResponse { live: true };
    Json(data)
}

#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/v2/health/ready",
    responses(
        (status = 200, description = "Service is ready", body = ReadyResponse),
        (status = 404, description = "Service not found", body = ErrorResponse,
            example = json!({"error": "No response"}))
    )
)]
pub async fn kserve_health_ready() -> Json<ReadyResponse> {
    let data = ReadyResponse { live: true };
    Json(data)
}

#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/v2",
    responses(
        (status = 200, description = "Metadata retrieved", body = MetadataServerResponse),
        (status = 404, description = "Service not found", body = ErrorResponse,
            example = json!({"error": "No response"}))
    )
)]
pub async fn kerve_server_metadata() -> Json<MetadataServerResponse> {
    let data = MetadataServerResponse {
        name: "text-generation-inference".to_string(),
        version: env!("CARGO_PKG_VERSION").to_string(),
        extensions: vec![
            "health".to_string(),
            "models".to_string(),
            "metrics".to_string(),
        ],
    };
    Json(data)
}

#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/v2/models/{model_name}/versions/{model_version}",
    responses(
        (status = 200, description = "Model version metadata retrieved", body = MetadataServerResponse),
        (status = 404, description = "Model or version not found", body = ErrorResponse,
            example = json!({"error": "No response"}))
    )
)]
pub async fn kserve_model_metadata(
    Path((model_name, model_version)): Path<(String, String)>,
) -> Json<MetadataServerResponse> {
    let data = MetadataServerResponse {
        name: model_name,
        version: model_version,
        extensions: vec!["infer".to_string(), "ready".to_string()],
    };
    Json(data)
}

#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/v2/models/{model_name}/versions/{model_version}/ready",
    responses(
        (status = 200, description = "Model version is ready", body = ReadyResponse),
        (status = 404, description = "Model or version not found", body = ErrorResponse,
            example = json!({"error": "No response"}))
    )
)]
pub async fn kserve_model_metadata_ready(
    Path((_model_name, _model_version)): Path<(String, String)>,
) -> Json<ReadyResponse> {
    let data = ReadyResponse { live: true };
    Json(data)
}

#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/v2/models/{model_name}/versions/{model_version}/infer",
    request_body = InferenceRequest,
    responses(
        (status = 200, description = "Inference executed successfully", body = InferenceOutput),
        (status = 404, description = "Model or version not found", body = ErrorResponse,
            example = json!({"error": "No response"}))
    )
)]
pub async fn kserve_model_infer(
    infer: Extension<Infer>,
    Extension(compute_type): Extension<ComputeType>,
    Json(payload): Json<InferenceRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
    let id = payload.id.clone();

    // Every input tensor is expected to carry UTF-8 encoded text in its `data` bytes.
    let str_inputs = payload
        .inputs
        .iter()
        .map(|input| {
            std::str::from_utf8(&input.data).map_err(|e| {
                (
                    StatusCode::UNPROCESSABLE_ENTITY,
                    Json(ErrorResponse {
                        error: e.to_string(),
                        error_type: "utf8".to_string(),
                    }),
                )
            })
        })
        .collect::<Result<Vec<_>, _>>()?;

    if str_inputs.len() != payload.outputs.len() {
        return Err((
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: "Inputs and outputs length mismatch".to_string(),
                error_type: "length mismatch".to_string(),
            }),
        ));
    }

    // Run one generation per (input, output) pair and await them concurrently.
    let output_chunks = str_inputs
        .iter()
        .zip(&payload.outputs)
        .map(|(str_input, output)| {
            let generate_request = GenerateRequest {
                inputs: str_input.to_string(),
                parameters: payload.parameters.clone(),
                add_special_tokens: true,
            };
            let infer = infer.clone();
            let compute_type = compute_type.clone();
            let span = tracing::Span::current();
            async move {
                generate_internal(infer, compute_type, Json(generate_request), span)
                    .await
                    .map(|(_, _, Json(generation))| {
                        let generation_as_bytes = generation.generated_text.as_bytes().to_vec();
                        OutputChunk {
                            name: output.name.clone(),
                            shape: vec![1, generation_as_bytes.len()],
                            datatype: "BYTES".to_string(),
                            data: generation_as_bytes,
                        }
                    })
                    .map_err(|_| {
                        (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            Json(ErrorResponse {
                                error: "Incomplete generation".into(),
                                error_type: "Incomplete generation".into(),
                            }),
                        )
                    })
            }
        })
        .collect::<FuturesUnordered<_>>()
        .try_collect::<Vec<_>>()
        .await?;

    let inference_output = InferenceOutput {
        id: id.clone(),
        outputs: output_chunks,
    };

    Ok((HeaderMap::new(), Json(inference_output)))
}