mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00

commit f4fd89b224 (parent 4139054b82)
feat: support vertex api
@@ -20,6 +20,24 @@ pub(crate) type GenerateStreamResponse = (
     UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
 );
 
+#[derive(Clone, Deserialize, ToSchema)]
+pub(crate) struct Instance {
+    pub inputs: String,
+    pub parameters: Option<GenerateParameters>,
+}
+
+#[derive(Deserialize, ToSchema)]
+pub(crate) struct VertexRequest {
+    pub instances: Vec<Instance>,
+    #[allow(dead_code)]
+    pub parameters: Option<GenerateParameters>,
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+pub(crate) struct VertexResponse {
+    pub predictions: Vec<String>,
+}
+
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {
@@ -153,7 +171,7 @@ pub struct Info {
     pub docker_label: Option<&'static str>,
 }
 
-#[derive(Clone, Debug, Deserialize, ToSchema)]
+#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
 pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
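Taken together, the new Instance, VertexRequest, and VertexResponse types above define the request/response contract for the Vertex endpoint. The following is a rough, self-contained sketch of the resulting wire format; it mirrors the structs locally and substitutes serde_json::Value for GenerateParameters, so it illustrates the shape rather than reproducing the crate's actual code:

use serde::{Deserialize, Serialize};

// Local stand-ins for the router's private types, for illustration only.
#[derive(Serialize, Deserialize)]
struct Instance {
    inputs: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    parameters: Option<serde_json::Value>, // stand-in for GenerateParameters
}

#[derive(Serialize, Deserialize)]
struct VertexRequest {
    instances: Vec<Instance>,
}

#[derive(Serialize, Deserialize)]
struct VertexResponse {
    predictions: Vec<String>,
}

fn main() {
    let request = VertexRequest {
        instances: vec![Instance {
            inputs: "What is Deep Learning?".to_string(),
            parameters: None,
        }],
    };
    // -> {"instances":[{"inputs":"What is Deep Learning?"}]}
    println!("{}", serde_json::to_string(&request).unwrap());

    let response = VertexResponse {
        predictions: vec!["Deep Learning is ...".to_string()],
    };
    // -> {"predictions":["Deep Learning is ..."]}
    println!("{}", serde_json::to_string(&response).unwrap());
}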
@@ -21,6 +21,36 @@ use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
 use tracing_subscriber::{EnvFilter, Layer};
 
+#[allow(dead_code)] // many of the fields are not used
+#[derive(Debug)]
+struct VertexAIConfig {
+    aip_http_port: u16,
+    aip_predict_route: String,
+    aip_health_route: String,
+}
+
+impl VertexAIConfig {
+    fn new(aip_http_port: u16, aip_predict_route: String, aip_health_route: String) -> Self {
+        Self {
+            aip_http_port,
+            aip_predict_route,
+            aip_health_route,
+        }
+    }
+    fn to_env(&self) {
+        // NOTE: this only sets the values for this process
+        // NOTE: child processes cannot set env vars for their parents
+        // TODO: find a way to set the values for the whole system
+        //   - maybe write to a file
+        //   - maybe use a shell script to set the values
+        //   - maybe these values are set upstream (before this process is started)
+        //   - if they are set upstream, maybe read them in when needed
+        std::env::set_var("AIP_HTTP_PORT", self.aip_http_port.to_string());
+        std::env::set_var("AIP_PREDICT_ROUTE", self.aip_predict_route.clone());
+        std::env::set_var("AIP_HEALTH_ROUTE", self.aip_health_route.clone());
+    }
+}
+
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
@@ -113,6 +143,11 @@ async fn main() -> Result<(), RouterError> {
         disable_grammar_support,
     } = args;
 
+    // Set the Vertex AI config and update the env
+    let vertex_ai_config =
+        VertexAIConfig::new(args.port, "/vertex".to_string(), "/health".to_string());
+    vertex_ai_config.to_env();
+
     // Launch Tokio runtime
     init_logging(otlp_endpoint, json_output);
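The NOTE/TODO block in to_env points out that std::env::set_var only affects the current process. One of the options it lists, reading the values in if they were already set upstream, could look roughly like the following minimal sketch; the variable names match the diff, while the fallback values are illustrative assumptions:

// Prefer AIP_* values already present in the environment (e.g. injected by the
// platform) and fall back to local defaults otherwise. Illustrative sketch only.
fn env_or(key: &str, fallback: &str) -> String {
    std::env::var(key).unwrap_or_else(|_| fallback.to_string())
}

fn main() {
    let port = env_or("AIP_HTTP_PORT", "8080");
    let predict_route = env_or("AIP_PREDICT_ROUTE", "/vertex");
    let health_route = env_or("AIP_HEALTH_ROUTE", "/health");
    println!("port={port} predict={predict_route} health={health_route}");
}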
@@ -7,7 +7,7 @@ use crate::{
     ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
     FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
     HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
-    StreamResponse, Token, TokenizeResponse, Validation,
+    StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -16,8 +16,10 @@ use axum::response::{IntoResponse, Response};
 use axum::routing::{get, post};
 use axum::{http, Json, Router};
 use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
+use futures::stream::FuturesUnordered;
 use futures::stream::StreamExt;
 use futures::Stream;
+use futures::TryStreamExt;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
@@ -693,6 +695,92 @@ async fn chat_completions(
     }
 }
 
+/// Generate tokens from Vertex request
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/v1/endpoints",
+    request_body = VertexRequest,
+    responses(
+        (status = 200, description = "Generated Text", body = VertexResponse),
+        (status = 424, description = "Generation Error", body = ErrorResponse,
+            example = json ! ({"error": "Request failed during generation"})),
+        (status = 429, description = "Model is overloaded", body = ErrorResponse,
+            example = json ! ({"error": "Model is overloaded"})),
+        (status = 422, description = "Input validation error", body = ErrorResponse,
+            example = json ! ({"error": "Input validation error"})),
+        (status = 500, description = "Incomplete generation", body = ErrorResponse,
+            example = json ! ({"error": "Incomplete generation"})),
+    )
+)]
+#[instrument(
+    skip_all,
+    fields(
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token,
+        seed,
+    )
+)]
+async fn vertex_compatibility(
+    Extension(infer): Extension<Infer>,
+    Json(req): Json<VertexRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    metrics::increment_counter!("tgi_request_count");
+
+    // check that there's at least one instance
+    if req.instances.is_empty() {
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Input validation error".to_string(),
+                error_type: "Input validation error".to_string(),
+            }),
+        ));
+    }
+
+    // Process all instances
+    let predictions = req
+        .instances
+        .iter()
+        .map(|instance| {
+            let generate_request = GenerateRequest {
+                inputs: instance.inputs.clone(),
+                parameters: GenerateParameters {
+                    do_sample: true,
+                    max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
+                    seed: instance.parameters.as_ref().and_then(|p| p.seed),
+                    details: true,
+                    decoder_input_details: true,
+                    ..Default::default()
+                },
+            };
+
+            async {
+                generate(Extension(infer.clone()), Json(generate_request))
+                    .await
+                    .map(|(_, Json(generation))| generation.generated_text)
+                    .map_err(|_| {
+                        (
+                            StatusCode::INTERNAL_SERVER_ERROR,
+                            Json(ErrorResponse {
+                                error: "Incomplete generation".into(),
+                                error_type: "Incomplete generation".into(),
+                            }),
+                        )
+                    })
+            }
+        })
+        .collect::<FuturesUnordered<_>>()
+        .try_collect::<Vec<_>>()
+        .await?;
+
+    let response = VertexResponse { predictions };
+    Ok((HeaderMap::new(), Json(response)).into_response())
+}
+
 /// Tokenize inputs
 #[utoipa::path(
     post,
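The handler above fans each instance out as its own future, drives them with FuturesUnordered, and gathers the results with try_collect, so the per-instance generations run concurrently and the first error aborts the whole batch. A minimal standalone sketch of that pattern, assuming the futures and tokio crates and using a toy async function in place of the router's generate:

use futures::stream::{FuturesUnordered, TryStreamExt};

// Toy stand-in for the per-instance generation call.
async fn fake_generate(input: &str) -> Result<String, String> {
    Ok(format!("echo: {input}"))
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let inputs = ["a", "b", "c"];

    // One future per input, polled concurrently; the first Err short-circuits.
    // Note: FuturesUnordered yields results in completion order, not input order.
    let outputs: Vec<String> = inputs
        .iter()
        .map(|input| fake_generate(input))
        .collect::<FuturesUnordered<_>>()
        .try_collect()
        .await?;

    println!("{outputs:?}");
    Ok(())
}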
@@ -953,6 +1041,7 @@ pub async fn run(
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
         .route("/v1/chat/completions", post(chat_completions))
+        .route("/vertex", post(vertex_compatibility))
         .route("/tokenize", post(tokenize))
         .route("/health", get(health))
         .route("/ping", get(health))
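Once the route is registered, the endpoint behaves like the router's other POST routes. A hedged client-side sketch, assuming a router listening on localhost:3000 and using the reqwest, tokio, and serde_json crates (none of which are implied by this diff):

use serde_json::json;

// Hypothetical call to the new /vertex route; the endpoint address, prompt text,
// and the reqwest dependency are assumptions for illustration.
#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let body = json!({
        "instances": [
            { "inputs": "What is Deep Learning?" },
            { "inputs": "Tell me a joke.", "parameters": { "max_new_tokens": 64 } }
        ]
    });

    let resp: serde_json::Value = reqwest::Client::new()
        .post("http://localhost:3000/vertex")
        .json(&body)
        .send()
        .await?
        .json()
        .await?;

    // Expected shape: {"predictions": ["...", "..."]}
    println!("{resp}");
    Ok(())
}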