Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-25 20:12:07 +00:00)
As spotted by @philschmid, the payload was only partially compliant with Vertex AI. The most compliant version flattens the generation kwargs to the same level as the `messages` key: Vertex AI still expects a list of instances, but each instance is now a fully OpenAI-compatible request. This is clearer, and it also aligns with the SageMaker integration. Kudos to @philschmid for spotting it, and apologies from my end for any inconvenience, @Narsil.
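In practice, each instance is now an OpenAI-style chat request with the generation kwargs at the same level as `messages`. For reference, this is the payload shape the deserialization test at the bottom of this file exercises (the field values are illustrative):

    {
      "instances": [
        {
          "messages": [{"role": "user", "content": "What's Deep Learning?"}],
          "max_tokens": 128,
          "top_p": 0.95,
          "temperature": 0.7
        }
      ]
    }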
185 lines · 6.0 KiB · Rust
use crate::infer::Infer;
use crate::server::{generate_internal, ComputeType};
use crate::{ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::Json;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use utoipa::ToSchema;

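/// A single Vertex instance in TGI's native format: a raw `inputs` prompt plus
/// optional `GenerateParameters`.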
#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct GenerateVertexInstance {
    #[schema(example = "What is Deep Learning?")]
    pub inputs: String,
    #[schema(nullable = true, default = "null", example = "null")]
    pub parameters: Option<GenerateParameters>,
}

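/// The two instance shapes accepted on `/vertex`. `#[serde(untagged)]` makes
/// serde try the variants in order: a payload with `inputs` deserializes as
/// `Generate`, while an OpenAI-style payload with `messages` falls through to
/// `Chat`.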
#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
#[serde(untagged)]
pub(crate) enum VertexInstance {
    Generate(GenerateVertexInstance),
    Chat(ChatRequest),
}

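/// Request envelope: Vertex AI sends a batch of instances under a top-level
/// `instances` key.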
#[derive(Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexRequest {
    #[serde(rename = "instances")]
    pub instances: Vec<VertexInstance>,
}

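/// Response envelope: Vertex AI expects the generated texts under a top-level
/// `predictions` array, one entry per instance.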
#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct VertexResponse {
    pub predictions: Vec<String>,
}

/// Generate tokens from Vertex request
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/vertex",
    request_body = VertexRequest,
    responses(
        (status = 200, description = "Generated Text", body = VertexResponse),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json!({"error": "Request failed during generation"})),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json!({"error": "Model is overloaded"})),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json!({"error": "Input validation error"})),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json!({"error": "Incomplete generation"})),
    )
)]
#[instrument(
    skip_all,
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
pub(crate) async fn vertex_compatibility(
    Extension(infer): Extension<Infer>,
    Extension(compute_type): Extension<ComputeType>,
    Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let span = tracing::Span::current();
    metrics::counter!("tgi_request_count").increment(1);

    // Check that there's at least one instance
    if req.instances.is_empty() {
        return Err((
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: "Input validation error".to_string(),
                error_type: "Input validation error".to_string(),
            }),
        ));
    }

    // Prepare futures for all instances
    let mut futures = Vec::with_capacity(req.instances.len());

    for instance in req.instances.into_iter() {
        // Map both accepted instance shapes onto TGI's internal `GenerateRequest`.
        let generate_request = match instance {
            VertexInstance::Generate(instance) => GenerateRequest {
                inputs: instance.inputs.clone(),
                add_special_tokens: true,
                parameters: GenerateParameters {
                    do_sample: true,
                    max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
                    seed: instance.parameters.as_ref().and_then(|p| p.seed),
                    details: true,
                    decoder_input_details: true,
                    ..Default::default()
                },
            },
            VertexInstance::Chat(instance) => {
                // Convert the OpenAI-style chat request into a `GenerateRequest`;
                // the tool-usage flag is not needed here, so it is discarded.
                let (generate_request, _using_tools): (GenerateRequest, bool) =
                    instance.try_into_generate(&infer)?;
                generate_request
            }
        };

        let infer_clone = infer.clone();
        let compute_type_clone = compute_type.clone();
        let span_clone = span.clone();

        futures.push(async move {
            generate_internal(
                Extension(infer_clone),
                compute_type_clone,
                Json(generate_request),
                span_clone,
            )
            .await
            .map(|(_, Json(generation))| generation.generated_text)
            .map_err(|_| {
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(ErrorResponse {
                        error: "Incomplete generation".into(),
                        error_type: "Incomplete generation".into(),
                    }),
                )
            })
        });
    }

    // Execute all futures in parallel and collect the results, returning early if any error occurs
    let results = futures::future::join_all(futures).await;
    let predictions: Result<Vec<_>, _> = results.into_iter().collect();
    let predictions = predictions?;

    let response = VertexResponse { predictions };
    Ok((HeaderMap::new(), Json(response)).into_response())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Message, MessageContent};

    #[test]
    fn vertex_deserialization() {
        let string = serde_json::json!({
            "instances": [
                {
                    "messages": [{"role": "user", "content": "What's Deep Learning?"}],
                    "max_tokens": 128,
                    "top_p": 0.95,
                    "temperature": 0.7
                }
            ]
        });
        let request: VertexRequest = serde_json::from_value(string).expect("Can deserialize");
        assert_eq!(
            request,
            VertexRequest {
                instances: vec![VertexInstance::Chat(ChatRequest {
                    messages: vec![Message {
                        role: "user".to_string(),
                        content: MessageContent::SingleText("What's Deep Learning?".to_string()),
                        name: None,
                    }],
                    max_tokens: Some(128),
                    top_p: Some(0.95),
                    temperature: Some(0.7),
                    ..Default::default()
                })]
            }
        );
    }
}