Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)

feat: support vertex api

commit f4fd89b224 (parent 4139054b82)
@@ -20,6 +20,24 @@ pub(crate) type GenerateStreamResponse = (
     UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
 );
 
+#[derive(Clone, Deserialize, ToSchema)]
+pub(crate) struct Instance {
+    pub inputs: String,
+    pub parameters: Option<GenerateParameters>,
+}
+
+#[derive(Deserialize, ToSchema)]
+pub(crate) struct VertexRequest {
+    pub instances: Vec<Instance>,
+    #[allow(dead_code)]
+    pub parameters: Option<GenerateParameters>,
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+pub(crate) struct VertexResponse {
+    pub predictions: Vec<String>,
+}
+
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {
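For orientation, the sketch below (not part of the commit) shows a JSON body matching the shape `VertexRequest` and `Instance` deserialize. The only per-instance parameters the new handler actually reads are `max_new_tokens` and `seed`, so those are the ones worth sending; the values here are made up.

// Illustrative only: a request body compatible with the `VertexRequest`
// struct above (field names from the commit; values invented).
// Requires the `serde_json` crate.
use serde_json::json;

fn main() {
    let body = json!({
        "instances": [
            {
                "inputs": "What is Deep Learning?",
                "parameters": { "max_new_tokens": 32, "seed": 42 }
            }
        ]
    });
    println!("{body}");
}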
@@ -153,7 +171,7 @@ pub struct Info {
     pub docker_label: Option<&'static str>,
 }
 
-#[derive(Clone, Debug, Deserialize, ToSchema)]
+#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
 pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
@@ -21,6 +21,36 @@ use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
 use tracing_subscriber::{EnvFilter, Layer};
 
+#[allow(dead_code)] // many of the fields are not used
+#[derive(Debug)]
+struct VertexAIConfig {
+    aip_http_port: u16,
+    aip_predict_route: String,
+    aip_health_route: String,
+}
+
+impl VertexAIConfig {
+    fn new(aip_http_port: u16, aip_predict_route: String, aip_health_route: String) -> Self {
+        Self {
+            aip_http_port,
+            aip_predict_route,
+            aip_health_route,
+        }
+    }
+    fn to_env(&self) {
+        // NOTE: this will only set the values for this process
+        // NOTE: child processes cannot set env vars for their parents
+        // TODO: find a way to set the values for the whole system
+        // - maybe write to a file
+        // - maybe use a shell script to set the values
+        // - maybe these values are set upstream (before this process is started)
+        // - if set upstream maybe we read them in, if we need them?
+        std::env::set_var("AIP_HTTP_PORT", self.aip_http_port.to_string());
+        std::env::set_var("AIP_PREDICT_ROUTE", self.aip_predict_route.clone());
+        std::env::set_var("AIP_HEALTH_ROUTE", self.aip_health_route.clone());
+    }
+}
+
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
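The `AIP_HTTP_PORT`, `AIP_PREDICT_ROUTE`, and `AIP_HEALTH_ROUTE` names follow the environment-variable contract Vertex AI uses for custom serving containers. As the NOTE comments point out, `std::env::set_var` only affects the current process; a minimal sketch (assumed, not part of the commit) of reading the values back in-process:

// Sketch only: mirrors what `VertexAIConfig::to_env` does and shows that
// the variables are visible within the same process afterwards.
fn main() {
    std::env::set_var("AIP_HTTP_PORT", "3000");
    std::env::set_var("AIP_PREDICT_ROUTE", "/vertex");
    std::env::set_var("AIP_HEALTH_ROUTE", "/health");

    // `set_var` is process-local: this lookup works here, but a parent
    // shell that launched the router would not see these values.
    let port: u16 = std::env::var("AIP_HTTP_PORT")
        .ok()
        .and_then(|v| v.parse().ok())
        .unwrap_or(8080); // hypothetical fallback, not from the commit
    println!("AIP_HTTP_PORT = {port}");
}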
@@ -113,6 +143,11 @@ async fn main() -> Result<(), RouterError> {
         disable_grammar_support,
     } = args;
 
+    // Set Vertex AI config and update the env
+    let vertex_ai_config =
+        VertexAIConfig::new(args.port, "/vertex".to_string(), "/health".to_string());
+    vertex_ai_config.to_env();
+
     // Launch Tokio runtime
     init_logging(otlp_endpoint, json_output);
 
@@ -7,7 +7,7 @@ use crate::{
     ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
     FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
     HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
-    StreamResponse, Token, TokenizeResponse, Validation,
+    StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -16,8 +16,10 @@ use axum::response::{IntoResponse, Response};
 use axum::routing::{get, post};
 use axum::{http, Json, Router};
 use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
+use futures::stream::FuturesUnordered;
 use futures::stream::StreamExt;
 use futures::Stream;
+use futures::TryStreamExt;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
@@ -693,6 +695,92 @@ async fn chat_completions(
     }
 }
 
+/// Generate tokens from Vertex request
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/v1/endpoints",
+    request_body = VertexRequest,
+    responses(
+        (status = 200, description = "Generated Text", body = VertexResponse),
+        (status = 424, description = "Generation Error", body = ErrorResponse,
+            example = json ! ({"error": "Request failed during generation"})),
+        (status = 429, description = "Model is overloaded", body = ErrorResponse,
+            example = json ! ({"error": "Model is overloaded"})),
+        (status = 422, description = "Input validation error", body = ErrorResponse,
+            example = json ! ({"error": "Input validation error"})),
+        (status = 500, description = "Incomplete generation", body = ErrorResponse,
+            example = json ! ({"error": "Incomplete generation"})),
+    )
+)]
+#[instrument(
+    skip_all,
+    fields(
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token,
+        seed,
+    )
+)]
+async fn vertex_compatibility(
+    Extension(infer): Extension<Infer>,
+    Json(req): Json<VertexRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    metrics::increment_counter!("tgi_request_count");
+
+    // check that theres at least one instance
+    if req.instances.is_empty() {
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Input validation error".to_string(),
+                error_type: "Input validation error".to_string(),
+            }),
+        ));
+    }
+
+    // Process all instances
+    let predictions = req
+        .instances
+        .iter()
+        .map(|instance| {
+            let generate_request = GenerateRequest {
+                inputs: instance.inputs.clone(),
+                parameters: GenerateParameters {
+                    do_sample: true,
+                    max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
+                    seed: instance.parameters.as_ref().and_then(|p| p.seed),
+                    details: true,
+                    decoder_input_details: true,
+                    ..Default::default()
+                },
+            };
+
+            async {
+                generate(Extension(infer.clone()), Json(generate_request))
+                    .await
+                    .map(|(_, Json(generation))| generation.generated_text)
+                    .map_err(|_| {
+                        (
+                            StatusCode::INTERNAL_SERVER_ERROR,
+                            Json(ErrorResponse {
+                                error: "Incomplete generation".into(),
+                                error_type: "Incomplete generation".into(),
+                            }),
+                        )
+                    })
+            }
+        })
+        .collect::<FuturesUnordered<_>>()
+        .try_collect::<Vec<_>>()
+        .await?;
+
+    let response = VertexResponse { predictions };
+    Ok((HeaderMap::new(), Json(response)).into_response())
+}
+
 /// Tokenize inputs
 #[utoipa::path(
     post,
@@ -953,6 +1041,7 @@ pub async fn run(
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
         .route("/v1/chat/completions", post(chat_completions))
+        .route("/vertex", post(vertex_compatibility))
        .route("/tokenize", post(tokenize))
         .route("/health", get(health))
         .route("/ping", get(health))
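With `/vertex` wired to `vertex_compatibility`, a client call could look like the hedged sketch below (assumes a router listening on localhost:3000 and the `reqwest`, `tokio`, and `serde_json` crates; none of this is part of the commit). Per `VertexResponse`, the reply is a JSON object with a `predictions` array of strings.

// Hypothetical client for the new /vertex route; endpoint path and field
// names come from the diff above, everything else is illustrative.
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let resp: serde_json::Value = reqwest::Client::new()
        .post("http://localhost:3000/vertex")
        .json(&json!({
            "instances": [{ "inputs": "What is Deep Learning?" }]
        }))
        .send()
        .await?
        .json()
        .await?;
    // Expected shape: {"predictions": ["..."]}
    println!("{resp}");
    Ok(())
}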