From f4fd89b224a9ed2811cfa163cd9e630fd586f7ad Mon Sep 17 00:00:00 2001
From: drbh
Date: Tue, 16 Jan 2024 15:05:44 -0500
Subject: [PATCH] feat: support vertex api

---
 router/src/lib.rs    | 20 +++++++++-
 router/src/main.rs   | 35 +++++++++++++++++
 router/src/server.rs | 91 +++++++++++++++++++++++++++++++++++++++++++-
 3 files changed, 144 insertions(+), 2 deletions(-)

diff --git a/router/src/lib.rs b/router/src/lib.rs
index 8c7ca74b..ab922bf0 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -20,6 +20,24 @@ pub(crate) type GenerateStreamResponse = (
     UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
 );
 
+#[derive(Clone, Deserialize, ToSchema)]
+pub(crate) struct Instance {
+    pub inputs: String,
+    pub parameters: Option<GenerateParameters>,
+}
+
+#[derive(Deserialize, ToSchema)]
+pub(crate) struct VertexRequest {
+    pub instances: Vec<Instance>,
+    #[allow(dead_code)]
+    pub parameters: Option<GenerateParameters>,
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+pub(crate) struct VertexResponse {
+    pub predictions: Vec<String>,
+}
+
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {
@@ -153,7 +171,7 @@ pub struct Info {
     pub docker_label: Option<&'static str>,
 }
 
-#[derive(Clone, Debug, Deserialize, ToSchema)]
+#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
 pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
diff --git a/router/src/main.rs b/router/src/main.rs
index 457bca8e..65802d5c 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -21,6 +21,36 @@ use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
 use tracing_subscriber::{EnvFilter, Layer};
 
+#[allow(dead_code)] // many of the fields are not used
+#[derive(Debug)]
+struct VertexAIConfig {
+    aip_http_port: u16,
+    aip_predict_route: String,
+    aip_health_route: String,
+}
+
+impl VertexAIConfig {
+    fn new(aip_http_port: u16, aip_predict_route: String, aip_health_route: String) -> Self {
+        Self {
+            aip_http_port,
+            aip_predict_route,
+            aip_health_route,
+        }
+    }
+    fn to_env(&self) {
+        // NOTE: this will only set the values for this process
+        // NOTE: child processes cannot set env vars for their parents
+        // TODO: find a way to set the values for the whole system
+        // - maybe write to a file
+        // - maybe use a shell script to set the values
+        // - maybe these values are set upstream (before this process is started)
+        // - if set upstream, maybe we read them in when we need them
+        std::env::set_var("AIP_HTTP_PORT", self.aip_http_port.to_string());
+        std::env::set_var("AIP_PREDICT_ROUTE", self.aip_predict_route.clone());
+        std::env::set_var("AIP_HEALTH_ROUTE", self.aip_health_route.clone());
+    }
+}
+
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
@@ -113,6 +143,11 @@ async fn main() -> Result<(), RouterError> {
         disable_grammar_support,
     } = args;
 
+    // Set Vertex AI config and update the env
+    let vertex_ai_config =
+        VertexAIConfig::new(args.port, "/vertex".to_string(), "/health".to_string());
+    vertex_ai_config.to_env();
+
     // Launch Tokio runtime
     init_logging(otlp_endpoint, json_output);
 
diff --git a/router/src/server.rs b/router/src/server.rs
index 0fc76916..2846142f 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -7,7 +7,7 @@ use crate::{
     ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
     FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
     HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
-    StreamResponse, Token, TokenizeResponse, Validation,
+    StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -16,8 +16,10 @@ use axum::response::{IntoResponse, Response};
 use axum::routing::{get, post};
 use axum::{http, Json, Router};
 use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
+use futures::stream::FuturesUnordered;
 use futures::stream::StreamExt;
 use futures::Stream;
+use futures::TryStreamExt;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
@@ -693,6 +695,92 @@ async fn chat_completions(
     }
 }
 
+/// Generate tokens from Vertex request
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/v1/endpoints",
+    request_body = VertexRequest,
+    responses(
+        (status = 200, description = "Generated Text", body = VertexResponse),
+        (status = 424, description = "Generation Error", body = ErrorResponse,
+            example = json ! ({"error": "Request failed during generation"})),
+        (status = 429, description = "Model is overloaded", body = ErrorResponse,
+            example = json ! ({"error": "Model is overloaded"})),
+        (status = 422, description = "Input validation error", body = ErrorResponse,
+            example = json ! ({"error": "Input validation error"})),
+        (status = 500, description = "Incomplete generation", body = ErrorResponse,
({"error": "Incomplete generation"})), + ) + )] +#[instrument( + skip_all, + fields( + total_time, + validation_time, + queue_time, + inference_time, + time_per_token, + seed, + ) +)] +async fn vertex_compatibility( + Extension(infer): Extension, + Json(req): Json, +) -> Result)> { + metrics::increment_counter!("tgi_request_count"); + + // check that theres at least one instance + if req.instances.is_empty() { + return Err(( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: "Input validation error".to_string(), + error_type: "Input validation error".to_string(), + }), + )); + } + + // Process all instances + let predictions = req + .instances + .iter() + .map(|instance| { + let generate_request = GenerateRequest { + inputs: instance.inputs.clone(), + parameters: GenerateParameters { + do_sample: true, + max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens), + seed: instance.parameters.as_ref().and_then(|p| p.seed), + details: true, + decoder_input_details: true, + ..Default::default() + }, + }; + + async { + generate(Extension(infer.clone()), Json(generate_request)) + .await + .map(|(_, Json(generation))| generation.generated_text) + .map_err(|_| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Incomplete generation".into(), + error_type: "Incomplete generation".into(), + }), + ) + }) + } + }) + .collect::>() + .try_collect::>() + .await?; + + let response = VertexResponse { predictions }; + Ok((HeaderMap::new(), Json(response)).into_response()) +} + /// Tokenize inputs #[utoipa::path( post, @@ -953,6 +1041,7 @@ pub async fn run( .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/v1/chat/completions", post(chat_completions)) + .route("/vertex", post(vertex_compatibility)) .route("/tokenize", post(tokenize)) .route("/health", get(health)) .route("/ping", get(health))