feat: implement infer endpoint wrapper around generate

drbh 2024-05-23 01:00:55 +00:00
parent cb69b09a77
commit 501b7c4436
2 changed files with 192 additions and 44 deletions


@@ -9,6 +9,40 @@ use tracing::warn;
use utoipa::ToSchema;
use validation::Validation;
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct OutputChunk {
name: String,
shape: Vec<usize>,
datatype: String,
data: Vec<u8>,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct InferenceOutput {
id: String,
outputs: Vec<OutputChunk>,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct InferenceRequest {
pub id: String,
pub inputs: Vec<Input>,
pub outputs: Vec<Output>,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct Input {
pub name: String,
pub shape: Vec<usize>,
pub datatype: String,
pub data: Vec<u8>,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct Output {
pub name: String,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct LiveResponse {
pub live: bool,
}
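
For context, here is a minimal client-side sketch of the wire format these types describe (a sketch, not part of the commit; it assumes the serde and serde_json crates and mirrors the structs locally so it compiles standalone; the prompt travels as raw UTF-8 bytes in `data`, which is how the handler below decodes it):

    use serde::Serialize;

    #[derive(Serialize)]
    struct Input {
        name: String,
        shape: Vec<usize>,
        datatype: String,
        data: Vec<u8>,
    }

    #[derive(Serialize)]
    struct Output {
        name: String,
    }

    #[derive(Serialize)]
    struct InferenceRequest {
        id: String,
        inputs: Vec<Input>,
        outputs: Vec<Output>,
    }

    fn main() {
        // Build a one-input request; serde serializes Vec<u8> as a JSON
        // array of numbers, so the prompt bytes appear as [87, 104, ...].
        let request = InferenceRequest {
            id: "req-0".to_string(),
            inputs: vec![Input {
                name: "input0".to_string(),
                shape: vec![1],
                datatype: "BYTES".to_string(),
                data: b"What is Deep Learning?".to_vec(),
            }],
            outputs: vec![Output {
                name: "output0".to_string(),
            }],
        };
        // POST this body to /v2/models/{model_name}/versions/{model_version}/infer.
        println!("{}", serde_json::to_string_pretty(&request).unwrap());
    }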


@@ -19,7 +19,10 @@ use crate::{
};
use crate::{FunctionDefinition, ToolCall, ToolType};
#[cfg(feature = "kserve")]
use crate::{
InferenceOutput, InferenceRequest, LiveResponse, MetadataServerResponse, OutputChunk,
ReadyResponse,
};
use async_stream::__private::AsyncStream;
use axum::extract::Extension;
#[cfg(feature = "kserve")]
@@ -1448,36 +1451,144 @@ async fn get_v2_models_model_name_versions_model_version(
Path((model_name, model_version)): Path<(String, String)>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let data = MetadataServerResponse {
name: "gpt2".to_string(),
version: "1.0".to_string(),
name: model_name,
version: model_version,
extensions: vec!["infer".to_string(), "ready".to_string()],
};
Ok((HeaderMap::new(), Json(data)).into_response())
}
// #[cfg(feature = "kserve")]
// async fn get_v2_models_model_name_versions_model_version_ready() -> JsonValue {
// let name = "gpt2";
// let ready = true;
// json!({
// "name" : name,
// "ready": ready,
// })
// }
#[cfg(feature = "kserve")]
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/v2/models/{model_name}/versions/{model_version}/infer",
request_body = Json<InferenceRequest>,
responses(
(status = 200, description = "Inference response", body = JsonValue),
(status = 404, description = "No response", body = ErrorResponse,
example = json ! ({"error": "No response"})),
)
)]
async fn post_v2_models_model_name_versions_model_version_infer(
infer: Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Json(payload): Json<InferenceRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
println!("Received payload: {:?}", payload);
let id = payload.id.clone();
let str_inputs = payload
.inputs
.iter()
.map(|input| {
std::str::from_utf8(&input.data).map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: e.to_string(),
error_type: "utf8".to_string(),
}),
)
})
})
.collect::<Result<Vec<_>, _>>()?;
let output_chunks = payload
.inputs
.iter()
.zip(payload.outputs.iter())
.zip(str_inputs.iter())
.map(|((input, output), str_input)| {
let generate_request = GenerateRequest {
inputs: str_input.to_string(),
parameters: GenerateParameters {
best_of: None,
temperature: None,
repetition_penalty: None,
frequency_penalty: None,
top_k: None,
top_p: None,
typical_p: None,
do_sample: true,
max_new_tokens: Some(100),
return_full_text: None,
stop: Vec::new(),
truncate: None,
watermark: false,
details: true,
decoder_input_details: false,
seed: None,
top_n_tokens: None,
grammar: None,
},
};
async {
let span = tracing::Span::current();
generate_internal(
infer.clone(),
compute_type.clone(),
Json(generate_request),
span.clone(),
)
.await
.map(|(_, Json(generation))| {
let generation_as_bytes = generation.generated_text.as_bytes().to_vec();
let output_name = output.name.clone();
let output_shape = input.shape.clone();
OutputChunk {
name: output_name,
shape: output_shape,
datatype: "BYTES".to_string(),
data: generation_as_bytes,
}
})
.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Incomplete generation".into(),
error_type: "Incomplete generation".into(),
}),
)
})
}
})
.collect::<FuturesUnordered<_>>()
.try_collect::<Vec<_>>()
.await?;
let inference_output = InferenceOutput {
id: id.clone(),
outputs: output_chunks,
};
Ok((HeaderMap::new(), Json(inference_output)).into_response())
}
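
The fan-out above leans on `FuturesUnordered`; below is a self-contained sketch of that pattern (assumptions: the futures and tokio crates, and a hypothetical `work` function standing in for `generate_internal`). Note that `FuturesUnordered` yields results in completion order rather than submission order, so the collected output chunks are not guaranteed to line up with the request's input order:

    use futures::stream::{FuturesUnordered, TryStreamExt};

    // Hypothetical stand-in for generate_internal: one fallible async unit of work.
    async fn work(x: u32) -> Result<u32, String> {
        Ok(x * 2)
    }

    #[tokio::main]
    async fn main() -> Result<(), String> {
        // Each input becomes a future; FuturesUnordered polls them
        // concurrently and try_collect short-circuits on the first error.
        let results: Vec<u32> = (0u32..4)
            .map(work)
            .collect::<FuturesUnordered<_>>()
            .try_collect()
            .await?;
        assert_eq!(results.len(), 4);
        Ok(())
    }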
#[cfg(feature = "kserve")]
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/v2/models/{model_name}/versions/{model_version}/ready",
responses(
(status = 200, description = "Model ready response", body = JsonValue),
(status = 404, description = "No response", body = ErrorResponse,
example = json ! ({"error": "No response"})),
)
)]
async fn get_v2_models_model_name_versions_model_version_ready(
Path((model_name, _model_version)): Path<(String, String)>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let data = json!({
"name": model_name,
"ready": true,
});
Ok((HeaderMap::new(), Json(data)).into_response())
}
#[derive(Clone, Debug)]
pub(crate) struct ComputeType(String);
@@ -1844,18 +1955,14 @@ pub async fn run(
#[derive(OpenApi)]
#[openapi(
paths(
post_v2_models_model_name_versions_model_version_infer,
get_v2_health_live,
get_v2_health_ready,
get_v2,
get_v2_models_model_name_versions_model_version,
get_v2_models_model_name_versions_model_version_ready,
),
components(schemas(LiveResponse, ReadyResponse, MetadataServerResponse))
)]
struct KServeApiDoc;
@@ -1919,24 +2026,31 @@ pub async fn run(
#[cfg(feature = "kserve")]
{
tracing::info!("Built with `kserve` feature");
// [X] Model | Infer | POST v2/models/<model_name>[/versions/<model_version>]/infer $inference_request $inference_response
// [X] Model | Metadata | GET v2/models/<model_name>[/versions/<model_version>] $metadata_model_response
// [X] Server | Ready | GET v2/health/ready $ready_server_response
// [X] Server | Live | GET v2/health/live $live_server_response
// [X] Server | Metadata | GET v2 $metadata_server_response
// [X] Model | Ready | GET v2/models/<model_name>[/versions/<model_version>]/ready
app = app
.route("/v2/health/live", get(get_v2_health_live))
.route("/v2/health/ready", get(get_v2_health_ready))
.route("/v2", get(get_v2))
.route(
// run inference for a model version
"/v2/models/:model_name/versions/:model_version/infer",
post(post_v2_models_model_name_versions_model_version_infer),
)
.route(
"/v2/models/:model_name/versions/:model_version",
get(get_v2_models_model_name_versions_model_version),
)
.route("/v2/health/ready", get(get_v2_health_ready))
.route("/v2/health/live", get(get_v2_health_live))
.route("/v2", get(get_v2))
.route(
// get readiness for a model version
"/v2/models/:model_name/versions/:model_version/ready",
get(get_v2_models_model_name_versions_model_version_ready),
);
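
With the routes wired up, a hypothetical smoke test against a local router could look like the sketch below (assumptions, not part of the commit: the server listens on 127.0.0.1:3000, a model is served as gpt2 version 1.0, and the reqwest crate is available with its json feature alongside tokio and serde_json):

    use reqwest::Client;

    #[tokio::main]
    async fn main() -> Result<(), reqwest::Error> {
        let client = Client::new();
        let base = "http://127.0.0.1:3000/v2/models/gpt2/versions/1.0";

        // Readiness for a model version.
        let ready = client.get(format!("{base}/ready")).send().await?;
        println!("ready: {}", ready.text().await?);

        // Inference for a model version; the prompt travels as UTF-8 bytes.
        let body = serde_json::json!({
            "id": "req-0",
            "inputs": [{
                "name": "input0",
                "shape": [1],
                "datatype": "BYTES",
                "data": b"What is Deep Learning?".to_vec()
            }],
            "outputs": [{ "name": "output0" }]
        });
        let infer = client.post(format!("{base}/infer")).json(&body).send().await?;
        println!("infer: {}", infer.text().await?);
        Ok(())
    }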
}
// add layers after routes