mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
feat: add kserve feature and basic routes
This commit is contained in:
parent
2a48a10043
commit
cb69b09a77
@ -58,3 +58,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
|
|||||||
default = ["ngrok"]
|
default = ["ngrok"]
|
||||||
ngrok = ["dep:ngrok"]
|
ngrok = ["dep:ngrok"]
|
||||||
google = []
|
google = []
|
||||||
|
kserve = []
|
||||||
|
@ -9,6 +9,30 @@ use tracing::warn;
|
|||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
use validation::Validation;
|
use validation::Validation;
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||||
|
pub struct LiveReponse {
|
||||||
|
pub live: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||||
|
pub struct ReadyResponse {
|
||||||
|
pub live: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||||
|
pub struct MetadataServerResponse {
|
||||||
|
pub name: String,
|
||||||
|
pub version: String,
|
||||||
|
pub extensions: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Type alias for generation responses
|
||||||
|
pub(crate) type GenerateStreamResponse = (
|
||||||
|
OwnedSemaphorePermit,
|
||||||
|
u32, // input_length
|
||||||
|
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
|
||||||
|
);
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema)]
|
#[derive(Clone, Deserialize, ToSchema)]
|
||||||
pub(crate) struct VertexInstance {
|
pub(crate) struct VertexInstance {
|
||||||
#[schema(example = "What is Deep Learning?")]
|
#[schema(example = "What is Deep Learning?")]
|
||||||
|
@ -18,8 +18,12 @@ use crate::{
|
|||||||
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
|
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
|
||||||
};
|
};
|
||||||
use crate::{FunctionDefinition, ToolCall, ToolType};
|
use crate::{FunctionDefinition, ToolCall, ToolType};
|
||||||
|
#[cfg(feature = "kserve")]
|
||||||
|
use crate::{LiveReponse, MetadataServerResponse, ReadyResponse};
|
||||||
use async_stream::__private::AsyncStream;
|
use async_stream::__private::AsyncStream;
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
|
#[cfg(feature = "kserve")]
|
||||||
|
use axum::extract::Path;
|
||||||
use axum::http::{HeaderMap, Method, StatusCode};
|
use axum::http::{HeaderMap, Method, StatusCode};
|
||||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||||
use axum::response::{IntoResponse, Response};
|
use axum::response::{IntoResponse, Response};
|
||||||
@ -1369,6 +1373,112 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
|
|||||||
prom_handle.render()
|
prom_handle.render()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "kserve")]
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[cfg(feature = "kserve")]
|
||||||
|
#[utoipa::path(
|
||||||
|
post,
|
||||||
|
tag = "Text Generation Inference",
|
||||||
|
path = "/v2/health/live",
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Live response", body = JsonValue),
|
||||||
|
(status = 404, description = "No response", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "No response"})),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
// https://github.com/kserve/open-inference-protocol/blob/main/specification/protocol/inference_rest.md
|
||||||
|
async fn get_v2_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
let data = LiveReponse { live: true };
|
||||||
|
Ok((HeaderMap::new(), Json(data)).into_response())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "kserve")]
|
||||||
|
#[utoipa::path(
|
||||||
|
post,
|
||||||
|
tag = "Text Generation Inference",
|
||||||
|
path = "/v2/health/ready",
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Ready response", body = JsonValue),
|
||||||
|
(status = 404, description = "No response", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "No response"})),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
async fn get_v2_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
let data = ReadyResponse { live: true };
|
||||||
|
Ok((HeaderMap::new(), Json(data)).into_response())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "kserve")]
|
||||||
|
#[utoipa::path(
|
||||||
|
get,
|
||||||
|
tag = "Text Generation Inference",
|
||||||
|
path = "/v2",
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Metadata response", body = JsonValue),
|
||||||
|
(status = 404, description = "No response", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "No response"})),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
async fn get_v2() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
let data = MetadataServerResponse {
|
||||||
|
name: "text-generation-inference".to_string(),
|
||||||
|
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||||
|
extensions: vec![
|
||||||
|
"health".to_string(),
|
||||||
|
"models".to_string(),
|
||||||
|
"metrics".to_string(),
|
||||||
|
],
|
||||||
|
};
|
||||||
|
Ok((HeaderMap::new(), Json(data)).into_response())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "kserve")]
|
||||||
|
#[utoipa::path(
|
||||||
|
get,
|
||||||
|
tag = "Text Generation Inference",
|
||||||
|
path = "/v2/models/{model_name}/versions/{model_version}",
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Model ready response", body = JsonValue),
|
||||||
|
(status = 404, description = "No response", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "No response"})),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
async fn get_v2_models_model_name_versions_model_version(
|
||||||
|
Path((model_name, model_version)): Path<(String, String)>,
|
||||||
|
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
let data = MetadataServerResponse {
|
||||||
|
name: "gpt2".to_string(),
|
||||||
|
version: "1.0".to_string(),
|
||||||
|
extensions: vec!["infer".to_string(), "ready".to_string()],
|
||||||
|
};
|
||||||
|
Ok((HeaderMap::new(), Json(data)).into_response())
|
||||||
|
}
|
||||||
|
|
||||||
|
// #[cfg(feature = "kserve")]
|
||||||
|
// async fn get_v2_models_model_name_versions_model_version_ready() -> JsonValue {
|
||||||
|
// let name = "gpt2";
|
||||||
|
// let ready = true;
|
||||||
|
// json!({
|
||||||
|
// "name" : name,
|
||||||
|
// "ready": ready,
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // TODO: Implement this route and resolve the req/res types
|
||||||
|
// #[cfg(feature = "kserve")]
|
||||||
|
// async fn post_v2_models_model_name_versions_model_version_infer() -> StatusCode {
|
||||||
|
// // $inference_request =
|
||||||
|
// // {
|
||||||
|
// // "id" : $string #optional,
|
||||||
|
// // "parameters" : $parameters #optional,
|
||||||
|
// // "inputs" : [ $request_input, ... ],
|
||||||
|
// // "outputs" : [ $request_output, ... ] #optional
|
||||||
|
// // }
|
||||||
|
|
||||||
|
// StatusCode::OK
|
||||||
|
// }
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub(crate) struct ComputeType(String);
|
pub(crate) struct ComputeType(String);
|
||||||
|
|
||||||
@ -1711,7 +1821,6 @@ pub async fn run(
|
|||||||
|
|
||||||
// Define VertextApiDoc conditionally only if the "google" feature is enabled
|
// Define VertextApiDoc conditionally only if the "google" feature is enabled
|
||||||
let doc = {
|
let doc = {
|
||||||
// avoid `mut` if possible
|
|
||||||
#[cfg(feature = "google")]
|
#[cfg(feature = "google")]
|
||||||
{
|
{
|
||||||
use crate::VertexInstance;
|
use crate::VertexInstance;
|
||||||
@ -1723,13 +1832,40 @@ pub async fn run(
|
|||||||
)]
|
)]
|
||||||
struct VertextApiDoc;
|
struct VertextApiDoc;
|
||||||
|
|
||||||
// limiting mutability to the smallest scope necessary
|
|
||||||
let mut doc = ApiDoc::openapi();
|
let mut doc = ApiDoc::openapi();
|
||||||
doc.merge(VertextApiDoc::openapi());
|
doc.merge(VertextApiDoc::openapi());
|
||||||
doc
|
doc
|
||||||
}
|
}
|
||||||
#[cfg(not(feature = "google"))]
|
#[cfg(not(feature = "google"))]
|
||||||
|
{
|
||||||
|
// Define KServeApiDoc conditionally only if the "kserve" feature is enabled
|
||||||
|
#[cfg(feature = "kserve")]
|
||||||
|
{
|
||||||
|
#[derive(OpenApi)]
|
||||||
|
#[openapi(
|
||||||
|
paths(
|
||||||
|
get_v2_health_live,
|
||||||
|
get_v2_health_ready,
|
||||||
|
get_v2,
|
||||||
|
get_v2_models_model_name_versions_model_version,
|
||||||
|
// get_v2_models_model_name_versions_model_version_ready,
|
||||||
|
// post_v2_models_model_name_versions_model_version_infer
|
||||||
|
),
|
||||||
|
components(schemas(
|
||||||
|
LiveReponse,
|
||||||
|
ReadyResponse,
|
||||||
|
MetadataServerResponse,
|
||||||
|
))
|
||||||
|
)]
|
||||||
|
struct KServeApiDoc;
|
||||||
|
|
||||||
|
let mut doc = ApiDoc::openapi();
|
||||||
|
doc.merge(KServeApiDoc::openapi());
|
||||||
|
doc
|
||||||
|
}
|
||||||
|
#[cfg(not(feature = "kserve"))]
|
||||||
ApiDoc::openapi()
|
ApiDoc::openapi()
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Configure Swagger UI
|
// Configure Swagger UI
|
||||||
@ -1780,6 +1916,29 @@ pub async fn run(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "kserve")]
|
||||||
|
{
|
||||||
|
tracing::info!("Built with `kserve` feature");
|
||||||
|
app = app
|
||||||
|
.route("/v2/health/live", get(get_v2_health_live))
|
||||||
|
.route("/v2/health/ready", get(get_v2_health_ready))
|
||||||
|
.route("/v2", get(get_v2))
|
||||||
|
.route(
|
||||||
|
// get metadata for a model version
|
||||||
|
"/v2/models/:model_name/versions/:model_version",
|
||||||
|
get(get_v2_models_model_name_versions_model_version),
|
||||||
|
);
|
||||||
|
// .route(
|
||||||
|
// // get readiness for a model version
|
||||||
|
// "/v2/models/:model_name/versions/:model_version/ready",
|
||||||
|
// get(get_v2_models_model_name_versions_model_version_ready),
|
||||||
|
// )
|
||||||
|
// .route(
|
||||||
|
// "/v2/models/:model_name/versions/:model_version/infer",
|
||||||
|
// post(post_v2_models_model_name_versions_model_version_infer),
|
||||||
|
// );
|
||||||
|
}
|
||||||
|
|
||||||
// add layers after routes
|
// add layers after routes
|
||||||
app = app
|
app = app
|
||||||
.layer(Extension(info))
|
.layer(Extension(info))
|
||||||
|
Loading…
Reference in New Issue
Block a user