feat: support vertex api

drbh 2024-01-16 15:05:44 -05:00
parent 4139054b82
commit f4fd89b224
3 changed files with 144 additions and 2 deletions

View File

@@ -20,6 +20,24 @@ pub(crate) type GenerateStreamResponse = (
    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
);

#[derive(Clone, Deserialize, ToSchema)]
pub(crate) struct Instance {
    pub inputs: String,
    pub parameters: Option<GenerateParameters>,
}

#[derive(Deserialize, ToSchema)]
pub(crate) struct VertexRequest {
    pub instances: Vec<Instance>,
    #[allow(dead_code)]
    pub parameters: Option<GenerateParameters>,
}

#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct VertexResponse {
    pub predictions: Vec<String>,
}

/// Hub type
#[derive(Clone, Debug, Deserialize)]
pub struct HubModelInfo {
@@ -153,7 +171,7 @@ pub struct Info {
    pub docker_label: Option<&'static str>,
}

#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
pub(crate) struct GenerateParameters {
    #[serde(default)]
    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]

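For reference, these new types mirror the Vertex AI prediction payload: a request carries a list of instances and the response returns one prediction string per instance. Below is a minimal, self-contained sketch, using simplified stand-ins for the crate-private structs above and made-up example inputs, showing the JSON wire shape they serialize to and from:

use serde::{Deserialize, Serialize};

// Simplified stand-ins for the crate-private types in the diff above,
// only to illustrate the wire format.
#[derive(Deserialize)]
struct Instance {
    inputs: String,
    parameters: Option<serde_json::Value>,
}

#[derive(Deserialize)]
struct VertexRequest {
    instances: Vec<Instance>,
}

#[derive(Serialize)]
struct VertexResponse {
    predictions: Vec<String>,
}

fn main() -> Result<(), serde_json::Error> {
    // What a Vertex AI endpoint POSTs to the serving container:
    let body = r#"{
        "instances": [
            { "inputs": "What is Deep Learning?", "parameters": { "max_new_tokens": 20 } },
            { "inputs": "Say hello" }
        ]
    }"#;
    let req: VertexRequest = serde_json::from_str(body)?;
    assert_eq!(req.instances.len(), 2);
    assert_eq!(req.instances[0].inputs, "What is Deep Learning?");
    assert!(req.instances[0].parameters.is_some());

    // The shape the router answers with:
    let resp = VertexResponse {
        predictions: vec!["...generated text...".to_string()],
    };
    println!("{}", serde_json::to_string_pretty(&resp)?);
    Ok(())
}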
View File

@@ -21,6 +21,36 @@ use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};

#[allow(dead_code)] // many of the fields are not used
#[derive(Debug)]
struct VertexAIConfig {
    aip_http_port: u16,
    aip_predict_route: String,
    aip_health_route: String,
}

impl VertexAIConfig {
    fn new(aip_http_port: u16, aip_predict_route: String, aip_health_route: String) -> Self {
        Self {
            aip_http_port,
            aip_predict_route,
            aip_health_route,
        }
    }

    fn to_env(&self) {
        // NOTE: this will only set the values for this process
        // NOTE: child processes cannot set env vars for their parents
        // TODO: find a way to set the values for the whole system
        // - maybe write to a file
        // - maybe use a shell script to set the values
        // - maybe these values are set upstream (before this process is started)
        // - if they are set upstream, we could read them in here when needed
        std::env::set_var("AIP_HTTP_PORT", self.aip_http_port.to_string());
        std::env::set_var("AIP_PREDICT_ROUTE", self.aip_predict_route.clone());
        std::env::set_var("AIP_HEALTH_ROUTE", self.aip_health_route.clone());
    }
}

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
@@ -113,6 +143,11 @@ async fn main() -> Result<(), RouterError> {
        disable_grammar_support,
    } = args;

    // Set Vertex AI config and update the env
    let vertex_ai_config =
        VertexAIConfig::new(args.port, "/vertex".to_string(), "/health".to_string());
    vertex_ai_config.to_env();

    // Launch Tokio runtime
    init_logging(otlp_endpoint, json_output);

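Vertex AI injects AIP_HTTP_PORT, AIP_PREDICT_ROUTE and AIP_HEALTH_ROUTE into custom serving containers, and std::env::set_var only affects the current process and its children. A minimal sketch of how a downstream component could read these back; the unwrap_or fallbacks here are illustrative placeholders, not the router's actual defaults:

use std::env;

fn main() {
    // Reads the same variables VertexAIConfig::to_env sets for this process.
    let port: u16 = env::var("AIP_HTTP_PORT")
        .ok()
        .and_then(|v| v.parse().ok())
        .unwrap_or(3000);
    let predict_route = env::var("AIP_PREDICT_ROUTE").unwrap_or_else(|_| "/vertex".to_string());
    let health_route = env::var("AIP_HEALTH_ROUTE").unwrap_or_else(|_| "/health".to_string());

    println!("predict: 0.0.0.0:{port}{predict_route}, health: 0.0.0.0:{port}{health_route}");
}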
View File

@@ -7,7 +7,7 @@ use crate::{
    ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
    HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
    StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
@@ -16,8 +16,10 @@ use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{http, Json, Router};
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
use futures::stream::FuturesUnordered;
use futures::stream::StreamExt;
use futures::Stream;
use futures::TryStreamExt;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use std::convert::Infallible;
use std::net::SocketAddr;
@@ -693,6 +695,92 @@ async fn chat_completions(
    }
}

/// Generate tokens from Vertex request
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/v1/endpoints",
request_body = VertexRequest,
responses(
(status = 200, description = "Generated Text", body = VertexResponse),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(
skip_all,
fields(
total_time,
validation_time,
queue_time,
inference_time,
time_per_token,
seed,
)
)]
async fn vertex_compatibility(
    Extension(infer): Extension<Infer>,
    Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    metrics::increment_counter!("tgi_request_count");

    // check that there's at least one instance
    if req.instances.is_empty() {
        return Err((
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: "Input validation error".to_string(),
                error_type: "Input validation error".to_string(),
            }),
        ));
    }

    // Process all instances
    let predictions = req
        .instances
        .iter()
        .map(|instance| {
            let generate_request = GenerateRequest {
                inputs: instance.inputs.clone(),
                parameters: GenerateParameters {
                    do_sample: true,
                    max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
                    seed: instance.parameters.as_ref().and_then(|p| p.seed),
                    details: true,
                    decoder_input_details: true,
                    ..Default::default()
                },
            };

            async {
                generate(Extension(infer.clone()), Json(generate_request))
                    .await
                    .map(|(_, Json(generation))| generation.generated_text)
                    .map_err(|_| {
                        (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            Json(ErrorResponse {
                                error: "Incomplete generation".into(),
                                error_type: "Incomplete generation".into(),
                            }),
                        )
                    })
            }
        })
        .collect::<FuturesUnordered<_>>()
        .try_collect::<Vec<_>>()
        .await?;

    let response = VertexResponse { predictions };
    Ok((HeaderMap::new(), Json(response)).into_response())
}

/// Tokenize inputs
#[utoipa::path(
post,
@@ -953,6 +1041,7 @@ pub async fn run(
        .route("/generate", post(generate))
        .route("/generate_stream", post(generate_stream))
        .route("/v1/chat/completions", post(chat_completions))
        .route("/vertex", post(vertex_compatibility))
        .route("/tokenize", post(tokenize))
        .route("/health", get(health))
        .route("/ping", get(health))