fix: refactor and improve types

Author: drbh
Date:   2024-05-23 16:11:39 +00:00
Parent: 501b7c4436
Commit: 0f1c4b12ca

2 changed files with 49 additions and 83 deletions

View File

@@ -9,6 +9,7 @@ use tracing::warn;
 use utoipa::ToSchema;
 use validation::Validation;
 
+#[cfg(feature = "kserve")]
 #[derive(Debug, Serialize, Deserialize, ToSchema)]
 pub struct OutputChunk {
     name: String,
@@ -17,42 +18,51 @@ pub struct OutputChunk {
     data: Vec<u8>,
 }
 
+#[cfg(feature = "kserve")]
 #[derive(Debug, Serialize, Deserialize, ToSchema)]
 pub struct InferenceOutput {
     id: String,
     outputs: Vec<OutputChunk>,
 }
 
-#[derive(Debug, Serialize, Deserialize, ToSchema)]
-pub struct InferenceRequest {
+#[cfg(feature = "kserve")]
+#[derive(Debug, Deserialize, ToSchema)]
+pub(crate) struct InferenceRequest {
     pub id: String,
+    #[serde(default = "default_parameters")]
+    pub parameters: GenerateParameters,
     pub inputs: Vec<Input>,
     pub outputs: Vec<Output>,
 }
 
+#[cfg(feature = "kserve")]
 #[derive(Debug, Serialize, Deserialize, ToSchema)]
-pub struct Input {
+pub(crate) struct Input {
     pub name: String,
     pub shape: Vec<usize>,
     pub datatype: String,
     pub data: Vec<u8>,
 }
 
+#[cfg(feature = "kserve")]
 #[derive(Debug, Serialize, Deserialize, ToSchema)]
-pub struct Output {
+pub(crate) struct Output {
     pub name: String,
 }
 
+#[cfg(feature = "kserve")]
 #[derive(Debug, Serialize, Deserialize, ToSchema)]
 pub struct LiveReponse {
     pub live: bool,
 }
 
+#[cfg(feature = "kserve")]
 #[derive(Debug, Serialize, Deserialize, ToSchema)]
 pub struct ReadyResponse {
     pub live: bool,
 }
 
+#[cfg(feature = "kserve")]
 #[derive(Debug, Serialize, Deserialize, ToSchema)]
 pub struct MetadataServerResponse {
     pub name: String,
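
For context (commentary, not part of the diff): with `parameters` falling back to `default_parameters`, the reworked `InferenceRequest` should deserialize from a minimal KServe V2 payload that omits the field entirely. A hedged sketch, assuming `serde_json` is in scope:

    // Hypothetical payload, not from this commit: `parameters` may be omitted
    // because of #[serde(default = "default_parameters")].
    let raw = r#"{
        "id": "42",
        "inputs": [{"name": "input_0", "shape": [1], "datatype": "BYTES", "data": [72, 105]}],
        "outputs": [{"name": "output_0"}]
    }"#;
    let request: InferenceRequest = serde_json::from_str(raw).expect("valid V2 payload");
    assert_eq!(request.id, "42");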

View File

@@ -1376,16 +1376,13 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     prom_handle.render()
 }
 
-#[cfg(feature = "kserve")]
-use serde_json::json;
-
 #[cfg(feature = "kserve")]
 #[utoipa::path(
     post,
     tag = "Text Generation Inference",
     path = "/v2/health/live",
     responses(
-        (status = 200, description = "Live response", body = JsonValue),
+        (status = 200, description = "Live response", body = LiveReponse),
         (status = 404, description = "No response", body = ErrorResponse,
             example = json ! ({"error": "No response"})),
     )
@@ -1402,7 +1399,7 @@ async fn get_v2_health_live() -> Result<Response, (StatusCode, Json<ErrorRespons
     tag = "Text Generation Inference",
     path = "/v2/health/ready",
     responses(
-        (status = 200, description = "Ready response", body = JsonValue),
+        (status = 200, description = "Ready response", body = ReadyResponse),
         (status = 404, description = "No response", body = ErrorResponse,
             example = json ! ({"error": "No response"})),
     )
@@ -1418,7 +1415,7 @@ async fn get_v2_health_ready() -> Result<Response, (StatusCode, Json<ErrorRespon
     tag = "Text Generation Inference",
     path = "/v2",
     responses(
-        (status = 200, description = "Metadata response", body = JsonValue),
+        (status = 200, description = "Metadata response", body = MetadataServerResponse),
         (status = 404, description = "No response", body = ErrorResponse,
             example = json ! ({"error": "No response"})),
     )
@@ -1442,7 +1439,7 @@ async fn get_v2() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     tag = "Text Generation Inference",
     path = "/v2/models/{model_name}/versions/{model_version}",
     responses(
-        (status = 200, description = "Model ready response", body = JsonValue),
+        (status = 200, description = "Model ready response", body = MetadataServerResponse),
         (status = 404, description = "No response", body = ErrorResponse,
             example = json ! ({"error": "No response"})),
     )
@@ -1465,7 +1462,7 @@ async fn get_v2_models_model_name_versions_model_version(
     path = "/v2/models/{model_name}/versions/{model_version}/infer",
     request_body = Json<InferenceRequest>,
     responses(
-        (status = 200, description = "Inference response", body = JsonValue),
+        (status = 200, description = "Inference response", body = InferenceOutput),
         (status = 404, description = "No response", body = ErrorResponse,
             example = json ! ({"error": "No response"})),
     )
@@ -1475,11 +1472,7 @@ async fn post_v2_models_model_name_versions_model_version_infer(
     Extension(compute_type): Extension<ComputeType>,
     Json(payload): Json<InferenceRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
-    println!("Received payload: {:?}", payload);
-
-    // let mut output_chunks = Vec::new();
-
     let id = payload.id.clone();
     let str_inputs = payload
         .inputs
         .iter()
@@ -1499,62 +1492,37 @@ async fn post_v2_models_model_name_versions_model_version_infer(
     let output_chunks = payload
         .inputs
         .iter()
-        .zip(payload.outputs.iter())
-        .zip(str_inputs.iter())
+        .zip(&payload.outputs)
+        .zip(&str_inputs)
         .map(|((input, output), str_input)| {
             let generate_request = GenerateRequest {
                 inputs: str_input.to_string(),
-                parameters: GenerateParameters {
-                    best_of: None,
-                    temperature: None,
-                    repetition_penalty: None,
-                    frequency_penalty: None,
-                    top_k: None,
-                    top_p: None,
-                    typical_p: None,
-                    do_sample: true,
-                    max_new_tokens: Some(100),
-                    return_full_text: None,
-                    stop: Vec::new(),
-                    truncate: None,
-                    watermark: false,
-                    details: true,
-                    decoder_input_details: false,
-                    seed: None,
-                    top_n_tokens: None,
-                    grammar: None,
-                },
+                parameters: payload.parameters.clone(),
             };
-            async {
-                let span = tracing::Span::current();
-                generate_internal(
-                    infer.clone(),
-                    compute_type.clone(),
-                    Json(generate_request),
-                    span.clone(),
-                )
-                .await
-                .map(|(_, Json(generation))| {
-                    let generation_as_bytes = generation.generated_text.as_bytes().to_vec();
-                    let output_name = output.name.clone();
-                    let output_shape = input.shape.clone();
-                    OutputChunk {
-                        name: output_name,
-                        shape: output_shape,
-                        datatype: "BYTES".to_string(),
-                        data: generation_as_bytes,
-                    }
-                })
-                .map_err(|_| {
-                    (
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        Json(ErrorResponse {
-                            error: "Incomplete generation".into(),
-                            error_type: "Incomplete generation".into(),
-                        }),
-                    )
-                })
+            let infer = infer.clone();
+            let compute_type = compute_type.clone();
+            let span = tracing::Span::current();
+            async move {
+                generate_internal(infer, compute_type, Json(generate_request), span)
+                    .await
+                    .map(|(_, Json(generation))| {
+                        let generation_as_bytes = generation.generated_text.as_bytes().to_vec();
+                        OutputChunk {
+                            name: output.name.clone(),
+                            shape: input.shape.clone(),
+                            datatype: "BYTES".to_string(),
+                            data: generation_as_bytes,
+                        }
+                    })
+                    .map_err(|_| {
+                        (
+                            StatusCode::INTERNAL_SERVER_ERROR,
+                            Json(ErrorResponse {
+                                error: "Incomplete generation".into(),
+                                error_type: "Incomplete generation".into(),
+                            }),
+                        )
+                    })
             }
         })
         .collect::<FuturesUnordered<_>>()
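
Aside (commentary, not part of the diff): the handler fans out one `generate_internal` future per input and gathers them via `FuturesUnordered`, which yields results in completion order rather than input order; the hunk cuts off before the downstream collection step. A self-contained sketch of the same fan-out pattern, assuming the `futures` and `tokio` crates:

    use futures::stream::{FuturesUnordered, TryStreamExt};

    // Stand-in for a per-input generation future.
    async fn double(x: u32) -> Result<u32, ()> {
        Ok(x * 2)
    }

    #[tokio::main]
    async fn main() -> Result<(), ()> {
        // One future per item, polled concurrently; the stream fails fast on
        // the first Err and otherwise collects the Ok values.
        let results: Vec<u32> = (0..4u32)
            .map(double)
            .collect::<FuturesUnordered<_>>()
            .try_collect()
            .await?;
        println!("{results:?}"); // completion order, not necessarily input order
        Ok(())
    }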
@@ -1575,18 +1543,15 @@ async fn post_v2_models_model_name_versions_model_version_infer(
     tag = "Text Generation Inference",
     path = "/v2/models/{model_name}/versions/{model_version}/ready",
     responses(
-        (status = 200, description = "Model ready response", body = JsonValue),
+        (status = 200, description = "Model ready response", body = ReadyResponse),
         (status = 404, description = "No response", body = ErrorResponse,
             example = json ! ({"error": "No response"})),
     )
 )]
 async fn get_v2_models_model_name_versions_model_version_ready(
-    Path((model_name, _model_version)): Path<(String, String)>,
+    Path((_model_name, _model_version)): Path<(String, String)>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
-    let data = json!({
-        "name": model_name,
-        "ready": true,
-    });
+    let data = ReadyResponse { live: true };
     Ok((HeaderMap::new(), Json(data)).into_response())
 }
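
Worth noting (commentary, not part of the diff): moving from the ad-hoc `json!` body to the typed struct also changes the wire format here. The old handler returned {"name": <model_name>, "ready": true}, whereas serde derives the JSON key from the field name, so the new body is:

    // ReadyResponse { live: true } serializes with the field name as the key.
    let body = serde_json::to_string(&ReadyResponse { live: true }).unwrap();
    assert_eq!(body, r#"{"live":true}"#);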
@@ -2026,14 +1991,6 @@ pub async fn run(
     #[cfg(feature = "kserve")]
     {
         tracing::info!("Built with `kserve` feature");
-
-        // [ ] Inference | | POST v2/models/[/versions/<model_version>]/infer $inference_request $inference_response
-        // [X] Model | Metadata | GET v2/models/<model_name>[/versions/<model_version>] $metadata_model_response
-        // [X] Server | Ready | GET v2/health/ready $ready_server_response
-        // [X] Server | Live | GET v2/health/live $live_server_response
-        // [X] Server | Metadata | GET v2 $metadata_server_response
-        // [ ] Model | Ready | GET v2/models/<model_name>[/versions/]/ready
-
         app = app
             .route(
                 "/v2/models/:model_name/versions/:model_version/infer",
@@ -2047,7 +2004,6 @@ pub async fn run(
         .route("/v2/health/live", get(get_v2_health_live))
         .route("/v2", get(get_v2))
         .route(
-            // get readiness for a model version
             "/v2/models/:model_name/versions/:model_version/ready",
             get(get_v2_models_model_name_versions_model_version_ready),
         );
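
Finally, a hypothetical client-side check against the routes registered above — a sketch only, assuming a server bound to localhost:3000 with the `kserve` feature enabled, plus `reqwest` (with its `json` feature), `tokio`, and `serde_json` on the client side:

    use serde_json::json;

    #[tokio::main]
    async fn main() -> Result<(), reqwest::Error> {
        // Placeholder model name and version; axum fills :model_name and :model_version.
        let url = "http://localhost:3000/v2/models/my-model/versions/1/infer";
        let body = json!({
            "id": "test-1",
            "inputs": [{
                "name": "input_0",
                "shape": [1],
                "datatype": "BYTES",
                "data": b"What is Deep Learning?".to_vec()
            }],
            "outputs": [{"name": "output_0"}]
        });
        let resp = reqwest::Client::new().post(url).json(&body).send().await?;
        println!("{}: {}", resp.status(), resp.text().await?);
        Ok(())
    }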