diff --git a/router/src/server.rs b/router/src/server.rs index f4e2447f..b157e76b 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1489,12 +1489,20 @@ async fn post_v2_models_model_name_versions_model_version_infer( }) .collect::<Result<Vec<_>, _>>()?; - let output_chunks = payload - .inputs + if str_inputs.len() != payload.outputs.len() { + return Err(( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: "Inputs and outputs length mismatch".to_string(), + error_type: "length mismatch".to_string(), + }), + )); + } + + let output_chunks = str_inputs .iter() .zip(&payload.outputs) - .zip(&str_inputs) - .map(|((input, output), str_input)| { + .map(|(str_input, output)| { let generate_request = GenerateRequest { inputs: str_input.to_string(), parameters: payload.parameters.clone(), @@ -1509,7 +1517,7 @@ async fn post_v2_models_model_name_versions_model_version_infer( let generation_as_bytes = generation.generated_text.as_bytes().to_vec(); OutputChunk { name: output.name.clone(), - shape: input.shape.clone(), + shape: vec![1, generation_as_bytes.len()], datatype: "BYTES".to_string(), data: generation_as_bytes, }