mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: improve syntax and conditional openapi docs
This commit is contained in:
parent
384b4eaec4
commit
111a3f6809
@ -21,16 +21,17 @@ pub(crate) type GenerateStreamResponse = (
|
||||
);
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema)]
|
||||
pub(crate) struct Instance {
|
||||
pub(crate) struct VertexInstance {
|
||||
#[schema(example = "What is Deep Learning?")]
|
||||
pub inputs: String,
|
||||
#[schema(nullable = true, default = "null", example = "null")]
|
||||
pub parameters: Option<GenerateParameters>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, ToSchema)]
|
||||
pub(crate) struct VertexRequest {
|
||||
pub instances: Vec<Instance>,
|
||||
#[allow(dead_code)]
|
||||
pub parameters: Option<GenerateParameters>,
|
||||
#[serde(rename = "instances")]
|
||||
pub instances: Vec<VertexInstance>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
@ -88,7 +89,7 @@ mod json_object_or_string_to_string {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||
#[serde(tag = "type", content = "value")]
|
||||
pub(crate) enum GrammarType {
|
||||
#[serde(
|
||||
|
@ -5,7 +5,7 @@ use crate::validation::ValidationError;
|
||||
use crate::{
|
||||
BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
|
||||
ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
|
||||
FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
|
||||
FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
|
||||
StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
|
||||
};
|
||||
@ -699,7 +699,7 @@ async fn chat_completions(
|
||||
#[utoipa::path(
|
||||
post,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/v1/endpoints",
|
||||
path = "/vertex",
|
||||
request_body = VertexRequest,
|
||||
responses(
|
||||
(status = 200, description = "Generated Text", body = VertexResponse),
|
||||
@ -726,6 +726,7 @@ async fn chat_completions(
|
||||
)]
|
||||
async fn vertex_compatibility(
|
||||
Extension(infer): Extension<Infer>,
|
||||
Extension(compute_type): Extension<ComputeType>,
|
||||
Json(req): Json<VertexRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
@ -759,7 +760,11 @@ async fn vertex_compatibility(
|
||||
};
|
||||
|
||||
async {
|
||||
generate(Extension(infer.clone()), Json(generate_request))
|
||||
generate(
|
||||
Extension(infer.clone()),
|
||||
Extension(compute_type.clone()),
|
||||
Json(generate_request),
|
||||
)
|
||||
.await
|
||||
.map(|(_, Json(generation))| generation.generated_text)
|
||||
.map_err(|_| {
|
||||
@ -906,6 +911,7 @@ pub async fn run(
|
||||
StreamResponse,
|
||||
StreamDetails,
|
||||
ErrorResponse,
|
||||
GrammarType,
|
||||
)
|
||||
),
|
||||
tags(
|
||||
@ -1030,8 +1036,30 @@ pub async fn run(
|
||||
docker_label: option_env!("DOCKER_LABEL"),
|
||||
};
|
||||
|
||||
// Define VertextApiDoc conditionally only if the "google" feature is enabled
|
||||
#[cfg(feature = "google")]
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
paths(vertex_compatibility),
|
||||
components(schemas(VertexInstance, VertexRequest, VertexResponse))
|
||||
)]
|
||||
struct VertextApiDoc;
|
||||
|
||||
let doc = {
|
||||
// avoid `mut` if possible
|
||||
#[cfg(feature = "google")]
|
||||
{
|
||||
// limiting mutability to the smallest scope necessary
|
||||
let mut doc = doc;
|
||||
doc.merge(VertextApiDoc::openapi());
|
||||
doc
|
||||
}
|
||||
#[cfg(not(feature = "google"))]
|
||||
ApiDoc::openapi()
|
||||
};
|
||||
|
||||
// Configure Swagger UI
|
||||
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi());
|
||||
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
|
||||
|
||||
// Define base and health routes
|
||||
let base_routes = Router::new()
|
||||
@ -1063,9 +1091,12 @@ pub async fn run(
|
||||
.merge(base_routes)
|
||||
.merge(aws_sagemaker_route);
|
||||
|
||||
if cfg!(feature = "google") {
|
||||
#[cfg(feature = "google")]
|
||||
{
|
||||
tracing::info!("Built with `google` feature");
|
||||
tracing::info!("Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected.");
|
||||
tracing::info!(
|
||||
"Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
|
||||
);
|
||||
if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
|
||||
app = app.route(&env_predict_route, post(vertex_compatibility));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user