fix: cleanup and improve api docs

drbh 2024-05-27 03:03:41 +00:00
parent 01bd1b2c26
commit 7488b982fa
2 changed files with 96 additions and 99 deletions

View File

@@ -10,22 +10,23 @@ use utoipa::ToSchema;
 use validation::Validation;

 #[cfg(feature = "kserve")]
+mod kserve {
+    use super::*;
+
 #[derive(Debug, Serialize, Deserialize, ToSchema)]
 pub struct OutputChunk {
-    name: String,
-    shape: Vec<usize>,
-    datatype: String,
-    data: Vec<u8>,
+    pub name: String,
+    pub shape: Vec<usize>,
+    pub datatype: String,
+    pub data: Vec<u8>,
 }

-#[cfg(feature = "kserve")]
 #[derive(Debug, Serialize, Deserialize, ToSchema)]
 pub struct InferenceOutput {
-    id: String,
-    outputs: Vec<OutputChunk>,
+    pub id: String,
+    pub outputs: Vec<OutputChunk>,
 }

-#[cfg(feature = "kserve")]
 #[derive(Debug, Deserialize, ToSchema)]
 pub(crate) struct InferenceRequest {
     pub id: String,
@@ -35,7 +36,6 @@ pub(crate) struct InferenceRequest {
     pub outputs: Vec<Output>,
 }

-#[cfg(feature = "kserve")]
 #[derive(Debug, Serialize, Deserialize, ToSchema)]
 pub(crate) struct Input {
     pub name: String,
@@ -44,31 +44,28 @@ pub(crate) struct Input {
     pub data: Vec<u8>,
 }

-#[cfg(feature = "kserve")]
 #[derive(Debug, Serialize, Deserialize, ToSchema)]
 pub(crate) struct Output {
     pub name: String,
 }

-#[cfg(feature = "kserve")]
 #[derive(Debug, Serialize, Deserialize, ToSchema)]
 pub struct LiveReponse {
     pub live: bool,
 }

-#[cfg(feature = "kserve")]
 #[derive(Debug, Serialize, Deserialize, ToSchema)]
 pub struct ReadyResponse {
     pub live: bool,
 }

-#[cfg(feature = "kserve")]
 #[derive(Debug, Serialize, Deserialize, ToSchema)]
 pub struct MetadataServerResponse {
     pub name: String,
     pub version: String,
     pub extensions: Vec<String>,
 }
+}

 /// Type alias for generation responses
 pub(crate) type GenerateStreamResponse = (
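
A note on the `pub` additions in this hunk: once the schema structs move inside `mod kserve`, the router code in the second changed file can only construct them and read their fields across the module boundary, so every field has to become `pub`. A minimal, self-contained sketch of the effect, not code from the repository; the sample values are made up:

```rust
// Sketch only: mirrors the OutputChunk shape from this diff, with made-up
// values, to show why the fields gained `pub`.
mod kserve {
    pub struct OutputChunk {
        pub name: String, // was `name: String` before this commit
        pub shape: Vec<usize>,
        pub datatype: String,
        pub data: Vec<u8>,
    }
}

fn main() {
    // Constructing the struct and reading its fields from outside
    // `mod kserve` both require the `pub` markers added in this diff
    // (E0451 / E0616 compile errors otherwise).
    let chunk = kserve::OutputChunk {
        name: "output0".to_string(),
        shape: vec![1, 4],
        datatype: "BYTES".to_string(),
        data: vec![0, 1, 2, 3],
    };
    println!(
        "{} {:?} {} {:?}",
        chunk.name, chunk.shape, chunk.datatype, chunk.data
    );
}
```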

View File

@@ -4,6 +4,12 @@ use crate::infer::v2::SchedulerV2;
 use crate::infer::v3::SchedulerV3;
 use crate::infer::{HealthCheck, Scheduler};
 use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
+use crate::health::Health;
+#[cfg(feature = "kserve")]
+use crate::kserve::{
+    InferenceOutput, InferenceRequest, LiveReponse, MetadataServerResponse, OutputChunk,
+    ReadyResponse,
+};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
@@ -18,11 +24,6 @@ use crate::{
     CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
 };
 use crate::{FunctionDefinition, ToolCall, ToolType};
-#[cfg(feature = "kserve")]
-use crate::{
-    InferenceOutput, InferenceRequest, LiveReponse, MetadataServerResponse, OutputChunk,
-    ReadyResponse,
-};
 use async_stream::__private::AsyncStream;
 use axum::extract::Extension;
 #[cfg(feature = "kserve")]
@@ -1382,13 +1383,12 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     tag = "Text Generation Inference",
     path = "/v2/health/live",
     responses(
-        (status = 200, description = "Live response", body = LiveReponse),
-        (status = 404, description = "No response", body = ErrorResponse,
-        example = json ! ({"error": "No response"})),
+        (status = 200, description = "Service is live", body = LiveReponse),
+        (status = 404, description = "Service not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
     )
 )]
-// https://github.com/kserve/open-inference-protocol/blob/main/specification/protocol/inference_rest.md
-async fn get_v2_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+async fn kserve_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let data = LiveReponse { live: true };
     Ok((HeaderMap::new(), Json(data)).into_response())
 }
@@ -1399,12 +1399,12 @@ async fn get_v2_health_live() -> Result<Response, (StatusCode, Json<ErrorRespons
     tag = "Text Generation Inference",
     path = "/v2/health/ready",
     responses(
-        (status = 200, description = "Ready response", body = ReadyResponse),
-        (status = 404, description = "No response", body = ErrorResponse,
-        example = json ! ({"error": "No response"})),
+        (status = 200, description = "Service is ready", body = ReadyResponse),
+        (status = 404, description = "Service not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
     )
 )]
-async fn get_v2_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+async fn kserve_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let data = ReadyResponse { live: true };
     Ok((HeaderMap::new(), Json(data)).into_response())
 }
@@ -1415,12 +1415,12 @@ async fn get_v2_health_ready() -> Result<Response, (StatusCode, Json<ErrorRespon
     tag = "Text Generation Inference",
     path = "/v2",
     responses(
-        (status = 200, description = "Metadata response", body = MetadataServerResponse),
-        (status = 404, description = "No response", body = ErrorResponse,
-        example = json ! ({"error": "No response"})),
+        (status = 200, description = "Metadata retrieved", body = MetadataServerResponse),
+        (status = 404, description = "Service not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
     )
 )]
-async fn get_v2() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let data = MetadataServerResponse {
         name: "text-generation-inference".to_string(),
         version: env!("CARGO_PKG_VERSION").to_string(),
@@ -1439,12 +1439,12 @@ async fn get_v2() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     tag = "Text Generation Inference",
     path = "/v2/models/{model_name}/versions/{model_version}",
     responses(
-        (status = 200, description = "Model ready response", body = MetadataServerResponse),
-        (status = 404, description = "No response", body = ErrorResponse,
-        example = json ! ({"error": "No response"})),
+        (status = 200, description = "Model version metadata retrieved", body = MetadataServerResponse),
+        (status = 404, description = "Model or version not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
     )
 )]
-async fn get_v2_models_model_name_versions_model_version(
+async fn kserve_model_metadata(
     Path((model_name, model_version)): Path<(String, String)>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let data = MetadataServerResponse {
@@ -1462,12 +1462,12 @@ async fn get_v2_models_model_name_versions_model_version(
     path = "/v2/models/{model_name}/versions/{model_version}/infer",
     request_body = Json<InferenceRequest>,
     responses(
-        (status = 200, description = "Inference response", body = InferenceOutput),
-        (status = 404, description = "No response", body = ErrorResponse,
-        example = json ! ({"error": "No response"})),
+        (status = 200, description = "Inference executed successfully", body = InferenceOutput),
+        (status = 404, description = "Model or version not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
     )
 )]
-async fn post_v2_models_model_name_versions_model_version_infer(
+async fn kserve_model_infer(
     infer: Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Json(payload): Json<InferenceRequest>,
@@ -1551,12 +1551,12 @@ async fn post_v2_models_model_name_versions_model_version_infer(
     tag = "Text Generation Inference",
     path = "/v2/models/{model_name}/versions/{model_version}/ready",
     responses(
-        (status = 200, description = "Model ready response", body = ReadyResponse),
-        (status = 404, description = "No response", body = ErrorResponse,
-        example = json ! ({"error": "No response"})),
+        (status = 200, description = "Model version is ready", body = ReadyResponse),
+        (status = 404, description = "Model or version not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
     )
 )]
-async fn get_v2_models_model_name_versions_model_version_ready(
+async fn kserve_model_metadata_ready(
     Path((_model_name, _model_version)): Path<(String, String)>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let data = ReadyResponse { live: true };
@@ -1928,12 +1928,12 @@ pub async fn run(
     #[derive(OpenApi)]
     #[openapi(
         paths(
-            post_v2_models_model_name_versions_model_version_infer,
-            get_v2_health_live,
-            get_v2_health_ready,
-            get_v2,
-            get_v2_models_model_name_versions_model_version,
-            get_v2_models_model_name_versions_model_version_ready,
+            kserve_model_infer,
+            kserve_health_live,
+            kserve_health_ready,
+            kerve_server_metadata,
+            kserve_model_metadata,
+            kserve_model_metadata_ready,
         ),
         components(schemas(LiveReponse, ReadyResponse, MetadataServerResponse,))
     )]
@@ -2002,18 +2002,18 @@ pub async fn run(
         app = app
             .route(
                 "/v2/models/:model_name/versions/:model_version/infer",
-                post(post_v2_models_model_name_versions_model_version_infer),
+                post(kserve_model_infer),
             )
             .route(
                 "/v2/models/:model_name/versions/:model_version",
-                get(get_v2_models_model_name_versions_model_version),
+                get(kserve_model_metadata),
             )
-            .route("/v2/health/ready", get(get_v2_health_ready))
-            .route("/v2/health/live", get(get_v2_health_live))
-            .route("/v2", get(get_v2))
+            .route("/v2/health/ready", get(kserve_health_ready))
+            .route("/v2/health/live", get(kserve_health_live))
+            .route("/v2", get(kerve_server_metadata))
             .route(
                 "/v2/models/:model_name/versions/:model_version/ready",
-                get(get_v2_models_model_name_versions_model_version_ready),
+                get(kserve_model_metadata_ready),
             );
     }
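
With the handlers renamed and the routes re-registered above, the endpoints can be smoke-tested end to end. A hypothetical client sketch, not part of this commit, assuming a locally running router built with the `kserve` feature on the default port 3000, with `tokio`, `reqwest` (its `json` feature enabled), and `serde_json` available:

```rust
// Hypothetical smoke test against the KServe routes registered above.
// The base URL and dependency set are assumptions, not from this commit.
#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let base = "http://localhost:3000";

    // kserve_health_live serializes `LiveReponse { live: true }`,
    // so the body should be {"live":true}.
    let live: serde_json::Value = reqwest::get(format!("{base}/v2/health/live"))
        .await?
        .json()
        .await?;
    assert_eq!(live["live"], true);

    // kerve_server_metadata reports the crate name, version, and extensions.
    let meta: serde_json::Value = reqwest::get(format!("{base}/v2"))
        .await?
        .json()
        .await?;
    assert_eq!(meta["name"], "text-generation-inference");
    println!("server metadata: {meta}");
    Ok(())
}
```

The expected bodies follow from the handlers themselves: `kserve_health_live` returns `LiveReponse { live: true }`, and `kerve_server_metadata` fills `MetadataServerResponse` with the crate name and `CARGO_PKG_VERSION`.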