fix: improve infer and simplify

This commit is contained in:
drbh 2024-05-23 16:21:42 +00:00
parent 0f1c4b12ca
commit 01bd1b2c26

View File

@ -1489,12 +1489,20 @@ async fn post_v2_models_model_name_versions_model_version_infer(
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let output_chunks = payload if str_inputs.len() != payload.outputs.len() {
.inputs return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Inputs and outputs length mismatch".to_string(),
error_type: "length mismatch".to_string(),
}),
));
}
let output_chunks = str_inputs
.iter() .iter()
.zip(&payload.outputs) .zip(&payload.outputs)
.zip(&str_inputs) .map(|(str_input, output)| {
.map(|((input, output), str_input)| {
let generate_request = GenerateRequest { let generate_request = GenerateRequest {
inputs: str_input.to_string(), inputs: str_input.to_string(),
parameters: payload.parameters.clone(), parameters: payload.parameters.clone(),
@ -1509,7 +1517,7 @@ async fn post_v2_models_model_name_versions_model_version_infer(
let generation_as_bytes = generation.generated_text.as_bytes().to_vec(); let generation_as_bytes = generation.generated_text.as_bytes().to_vec();
OutputChunk { OutputChunk {
name: output.name.clone(), name: output.name.clone(),
shape: input.shape.clone(), shape: vec![1, generation_as_bytes.len()],
datatype: "BYTES".to_string(), datatype: "BYTES".to_string(),
data: generation_as_bytes, data: generation_as_bytes,
} }