mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: improve syntax and conditional openapi docs
This commit is contained in:
parent
384b4eaec4
commit
111a3f6809
@ -21,16 +21,17 @@ pub(crate) type GenerateStreamResponse = (
|
|||||||
);
|
);
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema)]
|
#[derive(Clone, Deserialize, ToSchema)]
|
||||||
pub(crate) struct Instance {
|
pub(crate) struct VertexInstance {
|
||||||
|
#[schema(example = "What is Deep Learning?")]
|
||||||
pub inputs: String,
|
pub inputs: String,
|
||||||
|
#[schema(nullable = true, default = "null", example = "null")]
|
||||||
pub parameters: Option<GenerateParameters>,
|
pub parameters: Option<GenerateParameters>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, ToSchema)]
|
#[derive(Deserialize, ToSchema)]
|
||||||
pub(crate) struct VertexRequest {
|
pub(crate) struct VertexRequest {
|
||||||
pub instances: Vec<Instance>,
|
#[serde(rename = "instances")]
|
||||||
#[allow(dead_code)]
|
pub instances: Vec<VertexInstance>,
|
||||||
pub parameters: Option<GenerateParameters>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||||
@ -88,7 +89,7 @@ mod json_object_or_string_to_string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize)]
|
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||||
#[serde(tag = "type", content = "value")]
|
#[serde(tag = "type", content = "value")]
|
||||||
pub(crate) enum GrammarType {
|
pub(crate) enum GrammarType {
|
||||||
#[serde(
|
#[serde(
|
||||||
|
@ -5,7 +5,7 @@ use crate::validation::ValidationError;
|
|||||||
use crate::{
|
use crate::{
|
||||||
BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
|
BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
|
||||||
ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
|
ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
|
||||||
FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
|
FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||||
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
|
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
|
||||||
StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
|
StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
|
||||||
};
|
};
|
||||||
@ -699,7 +699,7 @@ async fn chat_completions(
|
|||||||
#[utoipa::path(
|
#[utoipa::path(
|
||||||
post,
|
post,
|
||||||
tag = "Text Generation Inference",
|
tag = "Text Generation Inference",
|
||||||
path = "/v1/endpoints",
|
path = "/vertex",
|
||||||
request_body = VertexRequest,
|
request_body = VertexRequest,
|
||||||
responses(
|
responses(
|
||||||
(status = 200, description = "Generated Text", body = VertexResponse),
|
(status = 200, description = "Generated Text", body = VertexResponse),
|
||||||
@ -726,6 +726,7 @@ async fn chat_completions(
|
|||||||
)]
|
)]
|
||||||
async fn vertex_compatibility(
|
async fn vertex_compatibility(
|
||||||
Extension(infer): Extension<Infer>,
|
Extension(infer): Extension<Infer>,
|
||||||
|
Extension(compute_type): Extension<ComputeType>,
|
||||||
Json(req): Json<VertexRequest>,
|
Json(req): Json<VertexRequest>,
|
||||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
metrics::increment_counter!("tgi_request_count");
|
metrics::increment_counter!("tgi_request_count");
|
||||||
@ -759,18 +760,22 @@ async fn vertex_compatibility(
|
|||||||
};
|
};
|
||||||
|
|
||||||
async {
|
async {
|
||||||
generate(Extension(infer.clone()), Json(generate_request))
|
generate(
|
||||||
.await
|
Extension(infer.clone()),
|
||||||
.map(|(_, Json(generation))| generation.generated_text)
|
Extension(compute_type.clone()),
|
||||||
.map_err(|_| {
|
Json(generate_request),
|
||||||
(
|
)
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
.await
|
||||||
Json(ErrorResponse {
|
.map(|(_, Json(generation))| generation.generated_text)
|
||||||
error: "Incomplete generation".into(),
|
.map_err(|_| {
|
||||||
error_type: "Incomplete generation".into(),
|
(
|
||||||
}),
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
)
|
Json(ErrorResponse {
|
||||||
})
|
error: "Incomplete generation".into(),
|
||||||
|
error_type: "Incomplete generation".into(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect::<FuturesUnordered<_>>()
|
.collect::<FuturesUnordered<_>>()
|
||||||
@ -906,6 +911,7 @@ pub async fn run(
|
|||||||
StreamResponse,
|
StreamResponse,
|
||||||
StreamDetails,
|
StreamDetails,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
|
GrammarType,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
tags(
|
tags(
|
||||||
@ -1030,8 +1036,30 @@ pub async fn run(
|
|||||||
docker_label: option_env!("DOCKER_LABEL"),
|
docker_label: option_env!("DOCKER_LABEL"),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Define VertextApiDoc conditionally only if the "google" feature is enabled
|
||||||
|
#[cfg(feature = "google")]
|
||||||
|
#[derive(OpenApi)]
|
||||||
|
#[openapi(
|
||||||
|
paths(vertex_compatibility),
|
||||||
|
components(schemas(VertexInstance, VertexRequest, VertexResponse))
|
||||||
|
)]
|
||||||
|
struct VertextApiDoc;
|
||||||
|
|
||||||
|
// Build the OpenAPI document handed to Swagger UI. When the crate is compiled
// with the `google` feature, the Vertex paths and schemas declared on
// `VertextApiDoc` are merged into the base `ApiDoc` document; otherwise the
// base document is used unchanged.
let doc = {
    // avoid `mut` if possible
    #[cfg(feature = "google")]
    {
        // limiting mutability to the smallest scope necessary
        // Seed the mutable copy from the base document: there is no earlier
        // `doc` binding in scope to shadow, so it must be created here.
        let mut doc = ApiDoc::openapi();
        doc.merge(VertextApiDoc::openapi());
        doc
    }
    #[cfg(not(feature = "google"))]
    ApiDoc::openapi()
};
|
||||||
|
|
||||||
// Configure Swagger UI
|
// Configure Swagger UI
|
||||||
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi());
|
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
|
||||||
|
|
||||||
// Define base and health routes
|
// Define base and health routes
|
||||||
let base_routes = Router::new()
|
let base_routes = Router::new()
|
||||||
@ -1063,9 +1091,12 @@ pub async fn run(
|
|||||||
.merge(base_routes)
|
.merge(base_routes)
|
||||||
.merge(aws_sagemaker_route);
|
.merge(aws_sagemaker_route);
|
||||||
|
|
||||||
if cfg!(feature = "google") {
|
#[cfg(feature = "google")]
|
||||||
|
{
|
||||||
tracing::info!("Built with `google` feature");
|
tracing::info!("Built with `google` feature");
|
||||||
tracing::info!("Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected.");
|
tracing::info!(
|
||||||
|
"Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
|
||||||
|
);
|
||||||
if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
|
if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
|
||||||
app = app.route(&env_predict_route, post(vertex_compatibility));
|
app = app.route(&env_predict_route, post(vertex_compatibility));
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user