mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: refactor and improve types
This commit is contained in:
parent 501b7c4436
commit 0f1c4b12ca
@@ -9,6 +9,7 @@ use tracing::warn;
use utoipa::ToSchema;
use validation::Validation;

#[cfg(feature = "kserve")]
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct OutputChunk {
    name: String,
@@ -17,42 +18,51 @@ pub struct OutputChunk {
    data: Vec<u8>,
}

#[cfg(feature = "kserve")]
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct InferenceOutput {
    id: String,
    outputs: Vec<OutputChunk>,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct InferenceRequest {
#[cfg(feature = "kserve")]
#[derive(Debug, Deserialize, ToSchema)]
pub(crate) struct InferenceRequest {
    pub id: String,
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,
    pub inputs: Vec<Input>,
    pub outputs: Vec<Output>,
}

#[cfg(feature = "kserve")]
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct Input {
pub(crate) struct Input {
    pub name: String,
    pub shape: Vec<usize>,
    pub datatype: String,
    pub data: Vec<u8>,
}

#[cfg(feature = "kserve")]
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct Output {
pub(crate) struct Output {
    pub name: String,
}

#[cfg(feature = "kserve")]
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct LiveReponse {
    pub live: bool,
}

#[cfg(feature = "kserve")]
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ReadyResponse {
    pub live: bool,
}

#[cfg(feature = "kserve")]
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct MetadataServerResponse {
    pub name: String,
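A minimal, self-contained sketch of what the new #[serde(default = "default_parameters")] attribute buys: a request may omit `parameters` entirely and still deserialize. `Parameters` below is a stand-in for the crate's `GenerateParameters`, not the real type.

use serde::Deserialize;

#[derive(Debug, Default, Deserialize)]
struct Parameters {
    max_new_tokens: Option<u32>,
}

fn default_parameters() -> Parameters {
    Parameters::default()
}

#[derive(Debug, Deserialize)]
struct InferenceRequest {
    id: String,
    #[serde(default = "default_parameters")]
    parameters: Parameters,
    inputs: Vec<Input>,
    outputs: Vec<Output>,
}

#[derive(Debug, Deserialize)]
struct Input {
    name: String,
    shape: Vec<usize>,
    datatype: String,
    data: Vec<u8>,
}

#[derive(Debug, Deserialize)]
struct Output {
    name: String,
}

fn main() {
    // `parameters` is deliberately omitted; the serde default fills it in.
    let raw = r#"{
        "id": "42",
        "inputs": [{"name": "input0", "shape": [5], "datatype": "BYTES",
                    "data": [104, 101, 108, 108, 111]}],
        "outputs": [{"name": "output0"}]
    }"#;
    let req: InferenceRequest = serde_json::from_str(raw).expect("valid request");
    println!("{req:?}");
}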
@@ -1376,16 +1376,13 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
    prom_handle.render()
}

#[cfg(feature = "kserve")]
use serde_json::json;

#[cfg(feature = "kserve")]
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/v2/health/live",
    responses(
        (status = 200, description = "Live response", body = JsonValue),
        (status = 200, description = "Live response", body = LiveReponse),
        (status = 404, description = "No response", body = ErrorResponse,
        example = json ! ({"error": "No response"})),
    )
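The annotations above are utoipa attribute macros; each one documents a route, and the refactor points `body` at a concrete schema type instead of `JsonValue`. A minimal sketch (assuming utoipa 4.x and a conventionally spelled `LiveResponse` stand-in, not the crate's actual setup) of how such an annotation plus `ToSchema` feeds the generated OpenAPI document:

use serde::Serialize;
use utoipa::{OpenApi, ToSchema};

#[derive(Serialize, ToSchema)]
struct LiveResponse {
    live: bool,
}

#[utoipa::path(
    get,
    path = "/v2/health/live",
    responses((status = 200, description = "Live response", body = LiveResponse))
)]
async fn get_v2_health_live() -> axum::Json<LiveResponse> {
    axum::Json(LiveResponse { live: true })
}

#[derive(OpenApi)]
#[openapi(paths(get_v2_health_live), components(schemas(LiveResponse)))]
struct ApiDoc;

fn main() {
    // Print the OpenAPI document that the annotation generates.
    println!("{}", ApiDoc::openapi().to_pretty_json().unwrap());
}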
@@ -1402,7 +1399,7 @@ async fn get_v2_health_live() -> Result<Response, (StatusCode, Json<ErrorRespons
    tag = "Text Generation Inference",
    path = "/v2/health/ready",
    responses(
        (status = 200, description = "Ready response", body = JsonValue),
        (status = 200, description = "Ready response", body = ReadyResponse),
        (status = 404, description = "No response", body = ErrorResponse,
        example = json ! ({"error": "No response"})),
    )
@@ -1418,7 +1415,7 @@ async fn get_v2_health_ready() -> Result<Response, (StatusCode, Json<ErrorRespon
    tag = "Text Generation Inference",
    path = "/v2",
    responses(
        (status = 200, description = "Metadata response", body = JsonValue),
        (status = 200, description = "Metadata response", body = MetadataServerResponse),
        (status = 404, description = "No response", body = ErrorResponse,
        example = json ! ({"error": "No response"})),
    )
@@ -1442,7 +1439,7 @@ async fn get_v2() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    tag = "Text Generation Inference",
    path = "/v2/models/{model_name}/versions/{model_version}",
    responses(
        (status = 200, description = "Model ready response", body = JsonValue),
        (status = 200, description = "Model ready response", body = MetadataServerResponse),
        (status = 404, description = "No response", body = ErrorResponse,
        example = json ! ({"error": "No response"})),
    )
@@ -1465,7 +1462,7 @@ async fn get_v2_models_model_name_versions_model_version(
    path = "/v2/models/{model_name}/versions/{model_version}/infer",
    request_body = Json<InferenceRequest>,
    responses(
        (status = 200, description = "Inference response", body = JsonValue),
        (status = 200, description = "Inference response", body = InferenceOutput),
        (status = 404, description = "No response", body = ErrorResponse,
        example = json ! ({"error": "No response"})),
    )
@@ -1475,11 +1472,7 @@ async fn post_v2_models_model_name_versions_model_version_infer(
    Extension(compute_type): Extension<ComputeType>,
    Json(payload): Json<InferenceRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    println!("Received payload: {:?}", payload);

    // let mut output_chunks = Vec::new();
    let id = payload.id.clone();

    let str_inputs = payload
        .inputs
        .iter()
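The hunk cuts off before showing how `str_inputs` is built, but given that each `Input` carries raw BYTES in `data: Vec<u8>`, here is a hedged guess at the decoding step (lossy UTF-8 is an assumption, not necessarily the crate's choice):

fn main() {
    let inputs: Vec<Vec<u8>> = vec![b"hello".to_vec(), b"world".to_vec()];
    // Decode each raw BYTES payload back into a UTF-8 string before it is
    // handed to the generation pipeline.
    let str_inputs: Vec<String> = inputs
        .iter()
        .map(|data| String::from_utf8_lossy(data).into_owned())
        .collect();
    assert_eq!(str_inputs, vec!["hello", "world"]);
}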
@@ -1499,62 +1492,37 @@ async fn post_v2_models_model_name_versions_model_version_infer(
    let output_chunks = payload
        .inputs
        .iter()
        .zip(payload.outputs.iter())
        .zip(str_inputs.iter())
        .zip(&payload.outputs)
        .zip(&str_inputs)
        .map(|((input, output), str_input)| {
            let generate_request = GenerateRequest {
                inputs: str_input.to_string(),
                parameters: GenerateParameters {
                    best_of: None,
                    temperature: None,
                    repetition_penalty: None,
                    frequency_penalty: None,
                    top_k: None,
                    top_p: None,
                    typical_p: None,
                    do_sample: true,
                    max_new_tokens: Some(100),
                    return_full_text: None,
                    stop: Vec::new(),
                    truncate: None,
                    watermark: false,
                    details: true,
                    decoder_input_details: false,
                    seed: None,
                    top_n_tokens: None,
                    grammar: None,
                },
                parameters: payload.parameters.clone(),
            };

            async {
                let span = tracing::Span::current();
                generate_internal(
                    infer.clone(),
                    compute_type.clone(),
                    Json(generate_request),
                    span.clone(),
                )
                .await
                .map(|(_, Json(generation))| {
                    let generation_as_bytes = generation.generated_text.as_bytes().to_vec();
                    let output_name = output.name.clone();
                    let output_shape = input.shape.clone();
                    OutputChunk {
                        name: output_name,
                        shape: output_shape,
                        datatype: "BYTES".to_string(),
                        data: generation_as_bytes,
                    }
                })
                .map_err(|_| {
                    (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        Json(ErrorResponse {
                            error: "Incomplete generation".into(),
                            error_type: "Incomplete generation".into(),
                        }),
                    )
                })
            let infer = infer.clone();
            let compute_type = compute_type.clone();
            let span = tracing::Span::current();
            async move {
                generate_internal(infer, compute_type, Json(generate_request), span)
                    .await
                    .map(|(_, Json(generation))| {
                        let generation_as_bytes = generation.generated_text.as_bytes().to_vec();
                        OutputChunk {
                            name: output.name.clone(),
                            shape: input.shape.clone(),
                            datatype: "BYTES".to_string(),
                            data: generation_as_bytes,
                        }
                    })
                    .map_err(|_| {
                        (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            Json(ErrorResponse {
                                error: "Incomplete generation".into(),
                                error_type: "Incomplete generation".into(),
                            }),
                        )
                    })
            }
        })
        .collect::<FuturesUnordered<_>>()
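The refactored closure clones `infer`, `compute_type`, and the span up front so the future can be `async move`, then every per-input future is driven concurrently. A minimal, self-contained sketch of that zip-map-FuturesUnordered pattern (`fake_generate` is a hypothetical stand-in for `generate_internal`):

use futures::stream::{FuturesUnordered, StreamExt};

async fn fake_generate(input: String) -> Result<Vec<u8>, ()> {
    // Pretend generation: echo the input back as bytes.
    Ok(input.into_bytes())
}

#[tokio::main]
async fn main() {
    let inputs = vec!["hello".to_string(), "world".to_string()];
    let outputs = vec!["output0".to_string(), "output1".to_string()];

    let chunks = inputs
        .iter()
        .zip(&outputs)
        .map(|(input, output)| {
            let input = input.clone();
            let output = output.clone();
            // Each closure yields a future; nothing runs until it is polled.
            async move { fake_generate(input).await.map(|data| (output, data)) }
        })
        .collect::<FuturesUnordered<_>>()
        // Results arrive in completion order, not submission order.
        .collect::<Vec<_>>()
        .await;

    println!("{chunks:?}");
}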
@@ -1575,18 +1543,15 @@ async fn post_v2_models_model_name_versions_model_version_infer(
    tag = "Text Generation Inference",
    path = "/v2/models/{model_name}/versions/{model_version}/ready",
    responses(
        (status = 200, description = "Model ready response", body = JsonValue),
        (status = 200, description = "Model ready response", body = ReadyResponse),
        (status = 404, description = "No response", body = ErrorResponse,
        example = json ! ({"error": "No response"})),
    )
)]
async fn get_v2_models_model_name_versions_model_version_ready(
    Path((model_name, _model_version)): Path<(String, String)>,
    Path((_model_name, _model_version)): Path<(String, String)>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let data = json!({
        "name": model_name,
        "ready": true,
    });
    let data = ReadyResponse { live: true };
    Ok((HeaderMap::new(), Json(data)).into_response())
}

@@ -2026,14 +1991,6 @@ pub async fn run(
    #[cfg(feature = "kserve")]
    {
        tracing::info!("Built with `kserve` feature");

        // [ ] Inference | | POST v2/models/[/versions/<model_version>]/infer $inference_request $inference_response
        // [X] Model | Metadata | GET v2/models/<model_name>[/versions/<model_version>] $metadata_model_response
        // [X] Server | Ready | GET v2/health/ready $ready_server_response
        // [X] Server | Live | GET v2/health/live $live_server_response
        // [X] Server | Metadata | GET v2 $metadata_server_response
        // [ ] Model | Ready | GET v2/models/<model_name>[/versions/]/ready

        app = app
            .route(
                "/v2/models/:model_name/versions/:model_version/infer",
@@ -2047,7 +2004,6 @@ pub async fn run(
            .route("/v2/health/live", get(get_v2_health_live))
            .route("/v2", get(get_v2))
            .route(
                // get readiness for a model version
                "/v2/models/:model_name/versions/:model_version/ready",
                get(get_v2_models_model_name_versions_model_version_ready),
            );
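A minimal, self-contained sketch (assuming axum 0.7-style `:param` captures, matching the routes above) of registering the versioned readiness route and extracting both path segments as a tuple, the way the refactored handler does:

use axum::{extract::Path, routing::get, Json, Router};
use serde::Serialize;

#[derive(Serialize)]
struct ReadyResponse {
    live: bool,
}

// Mirrors the refactor above: path parameters are accepted but unused, and a
// static readiness value is reported.
async fn ready(
    Path((_model_name, _model_version)): Path<(String, String)>,
) -> Json<ReadyResponse> {
    Json(ReadyResponse { live: true })
}

#[tokio::main]
async fn main() {
    let app: Router = Router::new().route(
        "/v2/models/:model_name/versions/:model_version/ready",
        get(ready),
    );
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000").await.unwrap();
    axum::serve(app, listener).await.unwrap();
}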