feat: improve syntax and conditional openapi docs

This commit is contained in:
drbh 2024-02-19 16:16:41 +00:00
parent 384b4eaec4
commit 111a3f6809
2 changed files with 54 additions and 22 deletions

View File

@ -21,16 +21,17 @@ pub(crate) type GenerateStreamResponse = (
);
#[derive(Clone, Deserialize, ToSchema)]
pub(crate) struct Instance {
pub(crate) struct VertexInstance {
#[schema(example = "What is Deep Learning?")]
pub inputs: String,
#[schema(nullable = true, default = "null", example = "null")]
pub parameters: Option<GenerateParameters>,
}
#[derive(Deserialize, ToSchema)]
pub(crate) struct VertexRequest {
pub instances: Vec<Instance>,
#[allow(dead_code)]
pub parameters: Option<GenerateParameters>,
#[serde(rename = "instances")]
pub instances: Vec<VertexInstance>,
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]
@ -88,7 +89,7 @@ mod json_object_or_string_to_string {
}
}
#[derive(Clone, Debug, Deserialize)]
#[derive(Clone, Debug, Deserialize, ToSchema)]
#[serde(tag = "type", content = "value")]
pub(crate) enum GrammarType {
#[serde(

View File

@ -5,7 +5,7 @@ use crate::validation::ValidationError;
use crate::{
BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
};
@ -699,7 +699,7 @@ async fn chat_completions(
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/v1/endpoints",
path = "/vertex",
request_body = VertexRequest,
responses(
(status = 200, description = "Generated Text", body = VertexResponse),
@ -726,6 +726,7 @@ async fn chat_completions(
)]
async fn vertex_compatibility(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
metrics::increment_counter!("tgi_request_count");
@ -759,7 +760,11 @@ async fn vertex_compatibility(
};
async {
generate(Extension(infer.clone()), Json(generate_request))
generate(
Extension(infer.clone()),
Extension(compute_type.clone()),
Json(generate_request),
)
.await
.map(|(_, Json(generation))| generation.generated_text)
.map_err(|_| {
@ -906,6 +911,7 @@ pub async fn run(
StreamResponse,
StreamDetails,
ErrorResponse,
GrammarType,
)
),
tags(
@ -1030,8 +1036,30 @@ pub async fn run(
docker_label: option_env!("DOCKER_LABEL"),
};
// Define VertextApiDoc conditionally only if the "google" feature is enabled
#[cfg(feature = "google")]
#[derive(OpenApi)]
#[openapi(
paths(vertex_compatibility),
components(schemas(VertexInstance, VertexRequest, VertexResponse))
)]
struct VertextApiDoc;
let doc = {
// avoid `mut` if possible
#[cfg(feature = "google")]
{
// limiting mutability to the smallest scope necessary
let mut doc = doc;
doc.merge(VertextApiDoc::openapi());
doc
}
#[cfg(not(feature = "google"))]
ApiDoc::openapi()
};
// Configure Swagger UI
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi());
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
// Define base and health routes
let base_routes = Router::new()
@ -1063,9 +1091,12 @@ pub async fn run(
.merge(base_routes)
.merge(aws_sagemaker_route);
if cfg!(feature = "google") {
#[cfg(feature = "google")]
{
tracing::info!("Built with `google` feature");
tracing::info!("Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected.");
tracing::info!(
"Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
);
if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
app = app.route(&env_predict_route, post(vertex_compatibility));
}