Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)

feat: implement infer endpoint wrapper around generate

This commit is contained in:
parent cb69b09a77
commit 501b7c4436
@@ -9,6 +9,40 @@ use tracing::warn;
 use utoipa::ToSchema;
 use validation::Validation;
 
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct OutputChunk {
+    name: String,
+    shape: Vec<usize>,
+    datatype: String,
+    data: Vec<u8>,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct InferenceOutput {
+    id: String,
+    outputs: Vec<OutputChunk>,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct InferenceRequest {
+    pub id: String,
+    pub inputs: Vec<Input>,
+    pub outputs: Vec<Output>,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct Input {
+    pub name: String,
+    pub shape: Vec<usize>,
+    pub datatype: String,
+    pub data: Vec<u8>,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct Output {
+    pub name: String,
+}
+
 #[derive(Debug, Serialize, Deserialize, ToSchema)]
 pub struct LiveReponse {
     pub live: bool,
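The new types mirror the KServe V2 tensor layout: prompt text travels as raw UTF-8 bytes in `data`, with `datatype` set to "BYTES". A minimal sketch of building and serializing a one-input request (not part of the commit; the id, tensor names, and prompt are illustrative, and serde_json is assumed, which the router already uses):

    fn build_infer_request() -> serde_json::Result<String> {
        // serde_json encodes Vec<u8> as a JSON array of numbers, which is
        // exactly what the new handler decodes with std::str::from_utf8.
        let prompt = "What is Deep Learning?";
        let request = InferenceRequest {
            id: "example-request".to_string(),
            inputs: vec![Input {
                name: "input_0".to_string(),
                shape: vec![1],
                datatype: "BYTES".to_string(),
                data: prompt.as_bytes().to_vec(),
            }],
            outputs: vec![Output {
                name: "output_0".to_string(),
            }],
        };
        serde_json::to_string(&request)
    }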
@@ -19,7 +19,10 @@ use crate::{
 };
 use crate::{FunctionDefinition, ToolCall, ToolType};
 #[cfg(feature = "kserve")]
-use crate::{LiveReponse, MetadataServerResponse, ReadyResponse};
+use crate::{
+    InferenceOutput, InferenceRequest, LiveReponse, MetadataServerResponse, OutputChunk,
+    ReadyResponse,
+};
 use async_stream::__private::AsyncStream;
 use axum::extract::Extension;
 #[cfg(feature = "kserve")]
@@ -1448,36 +1451,144 @@ async fn get_v2_models_model_name_versions_model_version(
     Path((model_name, model_version)): Path<(String, String)>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let data = MetadataServerResponse {
-        name: "gpt2".to_string(),
-        version: "1.0".to_string(),
+        name: model_name,
+        version: model_version,
         extensions: vec!["infer".to_string(), "ready".to_string()],
     };
     Ok((HeaderMap::new(), Json(data)).into_response())
 }
 
-// #[cfg(feature = "kserve")]
-// async fn get_v2_models_model_name_versions_model_version_ready() -> JsonValue {
-//     let name = "gpt2";
-//     let ready = true;
-//     json!({
-//         "name" : name,
-//         "ready": ready,
-//     })
-// }
-
-// // TODO: Implement this route and resolve the req/res types
-// #[cfg(feature = "kserve")]
-// async fn post_v2_models_model_name_versions_model_version_infer() -> StatusCode {
-//     // $inference_request =
-//     // {
-//     //   "id" : $string #optional,
-//     //   "parameters" : $parameters #optional,
-//     //   "inputs" : [ $request_input, ... ],
-//     //   "outputs" : [ $request_output, ... ] #optional
-//     // }
-//     StatusCode::OK
-// }
+#[cfg(feature = "kserve")]
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/v2/models/{model_name}/versions/{model_version}/infer",
+    request_body = Json<InferenceRequest>,
+    responses(
+        (status = 200, description = "Inference response", body = JsonValue),
+        (status = 404, description = "No response", body = ErrorResponse,
+            example = json ! ({"error": "No response"})),
+    )
+)]
+async fn post_v2_models_model_name_versions_model_version_infer(
+    infer: Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
+    Json(payload): Json<InferenceRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    println!("Received payload: {:?}", payload);
+
+    // let mut output_chunks = Vec::new();
+    let id = payload.id.clone();
+
+    let str_inputs = payload
+        .inputs
+        .iter()
+        .map(|input| {
+            std::str::from_utf8(&input.data).map_err(|e| {
+                (
+                    StatusCode::UNPROCESSABLE_ENTITY,
+                    Json(ErrorResponse {
+                        error: e.to_string(),
+                        error_type: "utf8".to_string(),
+                    }),
+                )
+            })
+        })
+        .collect::<Result<Vec<_>, _>>()?;
+
+    let output_chunks = payload
+        .inputs
+        .iter()
+        .zip(payload.outputs.iter())
+        .zip(str_inputs.iter())
+        .map(|((input, output), str_input)| {
+            let generate_request = GenerateRequest {
+                inputs: str_input.to_string(),
+                parameters: GenerateParameters {
+                    best_of: None,
+                    temperature: None,
+                    repetition_penalty: None,
+                    frequency_penalty: None,
+                    top_k: None,
+                    top_p: None,
+                    typical_p: None,
+                    do_sample: true,
+                    max_new_tokens: Some(100),
+                    return_full_text: None,
+                    stop: Vec::new(),
+                    truncate: None,
+                    watermark: false,
+                    details: true,
+                    decoder_input_details: false,
+                    seed: None,
+                    top_n_tokens: None,
+                    grammar: None,
+                },
+            };
+
+            async {
+                let span = tracing::Span::current();
+                generate_internal(
+                    infer.clone(),
+                    compute_type.clone(),
+                    Json(generate_request),
+                    span.clone(),
+                )
+                .await
+                .map(|(_, Json(generation))| {
+                    let generation_as_bytes = generation.generated_text.as_bytes().to_vec();
+                    let output_name = output.name.clone();
+                    let output_shape = input.shape.clone();
+                    OutputChunk {
+                        name: output_name,
+                        shape: output_shape,
+                        datatype: "BYTES".to_string(),
+                        data: generation_as_bytes,
+                    }
+                })
+                .map_err(|_| {
+                    (
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        Json(ErrorResponse {
+                            error: "Incomplete generation".into(),
+                            error_type: "Incomplete generation".into(),
+                        }),
+                    )
+                })
+            }
+        })
+        .collect::<FuturesUnordered<_>>()
+        .try_collect::<Vec<_>>()
+        .await?;
+
+    let inference_output = InferenceOutput {
+        id: id.clone(),
+        outputs: output_chunks,
+    };
+
+    Ok((HeaderMap::new(), Json(inference_output)).into_response())
+}
+
+#[cfg(feature = "kserve")]
+#[utoipa::path(
+    get,
+    tag = "Text Generation Inference",
+    path = "/v2/models/{model_name}/versions/{model_version}/ready",
+    responses(
+        (status = 200, description = "Model ready response", body = JsonValue),
+        (status = 404, description = "No response", body = ErrorResponse,
+            example = json ! ({"error": "No response"})),
+    )
+)]
+async fn get_v2_models_model_name_versions_model_version_ready(
+    Path((model_name, _model_version)): Path<(String, String)>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let data = json!({
+        "name": model_name,
+        "ready": true,
+    });
+    Ok((HeaderMap::new(), Json(data)).into_response())
+}
 
 #[derive(Clone, Debug)]
 pub(crate) struct ComputeType(String);
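Taken together, the new handler decodes each input's `data` as UTF-8 (rejecting the request with 422 on failure), zips inputs with the requested outputs (note that `zip` silently drops trailing items when the two lists differ in length), runs one `generate_internal` call per pair concurrently through `FuturesUnordered`, and repacks each generated string as a "BYTES" chunk that reuses the input's shape. A client-side sketch of pulling the text back out of the response JSON (illustrative, not from the commit):

    fn decode_generated_text(body: &str) -> Option<Vec<String>> {
        // `body` is the raw JSON returned by the /infer route; each chunk's
        // "data" field arrives as a JSON array of byte values.
        let response: serde_json::Value = serde_json::from_str(body).ok()?;
        response["outputs"]
            .as_array()?
            .iter()
            .map(|chunk| {
                let bytes = chunk["data"]
                    .as_array()?
                    .iter()
                    .map(|n| n.as_u64().map(|b| b as u8))
                    .collect::<Option<Vec<u8>>>()?;
                String::from_utf8(bytes).ok()
            })
            .collect()
    }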
@@ -1844,18 +1955,14 @@ pub async fn run(
         #[derive(OpenApi)]
         #[openapi(
             paths(
+                post_v2_models_model_name_versions_model_version_infer,
                 get_v2_health_live,
                 get_v2_health_ready,
                 get_v2,
                 get_v2_models_model_name_versions_model_version,
-                // get_v2_models_model_name_versions_model_version_ready,
-                // post_v2_models_model_name_versions_model_version_infer
+                get_v2_models_model_name_versions_model_version_ready,
             ),
-            components(schemas(
-                LiveReponse,
-                ReadyResponse,
-                MetadataServerResponse,
-            ))
+            components(schemas(LiveReponse, ReadyResponse, MetadataServerResponse,))
         )]
         struct KServeApiDoc;
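`KServeApiDoc` remains a separate `#[derive(OpenApi)]` document rather than being folded into the router's main `ApiDoc`. One plausible way to surface the kserve paths in the same Swagger UI (a sketch under the assumption that `run()` builds `ApiDoc::openapi()` as it does for the other routes; the helper name is hypothetical) is utoipa's document merge:

    fn build_api_doc() -> utoipa::openapi::OpenApi {
        // Fold the kserve-only paths and schemas into the main document;
        // `merge` takes the union of paths, components, and tags.
        let mut doc = ApiDoc::openapi();
        #[cfg(feature = "kserve")]
        doc.merge(KServeApiDoc::openapi());
        doc
    }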
@@ -1919,24 +2026,31 @@ pub async fn run(
     #[cfg(feature = "kserve")]
     {
         tracing::info!("Built with `kserve` feature");
+
+        // [ ] Inference | | POST v2/models/<model_name>[/versions/<model_version>]/infer $inference_request $inference_response
+        // [X] Model | Metadata | GET v2/models/<model_name>[/versions/<model_version>] $metadata_model_response
+        // [X] Server | Ready | GET v2/health/ready $ready_server_response
+        // [X] Server | Live | GET v2/health/live $live_server_response
+        // [X] Server | Metadata | GET v2 $metadata_server_response
+        // [ ] Model | Ready | GET v2/models/<model_name>[/versions/<model_version>]/ready
+
         app = app
-            .route("/v2/health/live", get(get_v2_health_live))
-            .route("/v2/health/ready", get(get_v2_health_ready))
-            .route("/v2", get(get_v2))
             .route(
-                // get metadata for a model version
+                "/v2/models/:model_name/versions/:model_version/infer",
+                post(post_v2_models_model_name_versions_model_version_infer),
+            )
+            .route(
                 "/v2/models/:model_name/versions/:model_version",
                 get(get_v2_models_model_name_versions_model_version),
+            )
+            .route("/v2/health/ready", get(get_v2_health_ready))
+            .route("/v2/health/live", get(get_v2_health_live))
+            .route("/v2", get(get_v2))
+            .route(
+                // get readiness for a model version
+                "/v2/models/:model_name/versions/:model_version/ready",
+                get(get_v2_models_model_name_versions_model_version_ready),
             );
-        // .route(
-        //     // get readiness for a model version
-        //     "/v2/models/:model_name/versions/:model_version/ready",
-        //     get(get_v2_models_model_name_versions_model_version_ready),
-        // )
-        // .route(
-        //     "/v2/models/:model_name/versions/:model_version/infer",
-        //     post(post_v2_models_model_name_versions_model_version_infer),
-        // );
     }
 
     // add layers after routes
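With the routes registered ahead of the shared layers, the wrapper can be exercised end to end. A hypothetical smoke test against a router built with `--features kserve` (assumes the `reqwest` crate with its `json` feature, plus `tokio` and `serde_json`; the port is the router's default 3000, and the model name and version are placeholders):

    #[tokio::main]
    async fn main() -> Result<(), reqwest::Error> {
        let body = serde_json::json!({
            "id": "smoke-1",
            "inputs": [{
                "name": "input_0",
                "shape": [1],
                "datatype": "BYTES",
                // Vec<u8> fields deserialize from JSON arrays of numbers.
                "data": "What is Deep Learning?".as_bytes()
            }],
            "outputs": [{ "name": "output_0" }]
        });
        let response = reqwest::Client::new()
            .post("http://127.0.0.1:3000/v2/models/gpt2/versions/1/infer")
            .json(&body)
            .send()
            .await?;
        println!("{}", response.text().await?);
        Ok(())
    }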