diff --git a/router/src/lib.rs b/router/src/lib.rs index ab922bf0..b7285e65 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -21,16 +21,17 @@ pub(crate) type GenerateStreamResponse = ( ); #[derive(Clone, Deserialize, ToSchema)] -pub(crate) struct Instance { +pub(crate) struct VertexInstance { + #[schema(example = "What is Deep Learning?")] pub inputs: String, + #[schema(nullable = true, default = "null", example = "null")] pub parameters: Option, } #[derive(Deserialize, ToSchema)] pub(crate) struct VertexRequest { - pub instances: Vec, - #[allow(dead_code)] - pub parameters: Option, + #[serde(rename = "instances")] + pub instances: Vec, } #[derive(Clone, Deserialize, ToSchema, Serialize)] @@ -88,7 +89,7 @@ mod json_object_or_string_to_string { } } -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize, ToSchema)] #[serde(tag = "type", content = "value")] pub(crate) enum GrammarType { #[serde( diff --git a/router/src/server.rs b/router/src/server.rs index 72e0eea8..140fb014 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -5,7 +5,7 @@ use crate::validation::ValidationError; use crate::{ BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta, ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse, - FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, + FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse, }; @@ -699,7 +699,7 @@ async fn chat_completions( #[utoipa::path( post, tag = "Text Generation Inference", - path = "/v1/endpoints", + path = "/vertex", request_body = VertexRequest, responses( (status = 200, description = "Generated Text", body = VertexResponse), @@ -726,6 +726,7 @@ async fn chat_completions( )] async fn vertex_compatibility( Extension(infer): Extension, + Extension(compute_type): Extension, Json(req): Json, ) -> Result)> { metrics::increment_counter!("tgi_request_count"); @@ -759,18 +760,22 @@ async fn vertex_compatibility( }; async { - generate(Extension(infer.clone()), Json(generate_request)) - .await - .map(|(_, Json(generation))| generation.generated_text) - .map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: "Incomplete generation".into(), - error_type: "Incomplete generation".into(), - }), - ) - }) + generate( + Extension(infer.clone()), + Extension(compute_type.clone()), + Json(generate_request), + ) + .await + .map(|(_, Json(generation))| generation.generated_text) + .map_err(|_| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Incomplete generation".into(), + error_type: "Incomplete generation".into(), + }), + ) + }) } }) .collect::>() @@ -906,6 +911,7 @@ pub async fn run( StreamResponse, StreamDetails, ErrorResponse, + GrammarType, ) ), tags( @@ -1030,8 +1036,30 @@ pub async fn run( docker_label: option_env!("DOCKER_LABEL"), }; + // Define VertextApiDoc conditionally only if the "google" feature is enabled + #[cfg(feature = "google")] + #[derive(OpenApi)] + #[openapi( + paths(vertex_compatibility), + components(schemas(VertexInstance, VertexRequest, VertexResponse)) + )] + struct VertextApiDoc; + + let doc = { + // avoid `mut` if possible + #[cfg(feature = "google")] + { + // limiting mutability to the smallest scope necessary + let mut doc = doc; + doc.merge(VertextApiDoc::openapi()); + doc + } + #[cfg(not(feature = "google"))] + ApiDoc::openapi() + }; + // Configure Swagger UI - let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()); + let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc); // Define base and health routes let base_routes = Router::new() @@ -1063,9 +1091,12 @@ pub async fn run( .merge(base_routes) .merge(aws_sagemaker_route); - if cfg!(feature = "google") { + #[cfg(feature = "google")] + { tracing::info!("Built with `google` feature"); - tracing::info!("Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."); + tracing::info!( + "Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected." + ); if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") { app = app.route(&env_predict_route, post(vertex_compatibility)); }