use crate::infer::Infer;
use crate::server::{generate_internal, ComputeType};
use crate::{ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::Json;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use utoipa::ToSchema;

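/// A single raw-prompt instance in a Vertex AI prediction request.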
#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct GenerateVertexInstance {
    #[schema(example = "What is Deep Learning?")]
    pub inputs: String,
    #[schema(nullable = true, default = "null", example = "null")]
    pub parameters: Option<GenerateParameters>,
}

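/// Either a raw-prompt or a chat instance. `#[serde(untagged)]` tries the
/// variants in declaration order, so payloads with an `inputs` field
/// deserialize as `Generate` and chat-style payloads fall through to `Chat`.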
#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
#[serde(untagged)]
pub(crate) enum VertexInstance {
    Generate(GenerateVertexInstance),
    Chat(ChatRequest),
}

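/// The request body accepted on `/vertex`: a batch of instances to generate for.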
#[derive(Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexRequest {
    #[serde(rename = "instances")]
    pub instances: Vec<VertexInstance>,
}

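/// One generated text per instance, in request order.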
#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct VertexResponse {
    pub predictions: Vec<String>,
}

/// Generate tokens from Vertex request
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/vertex",
    request_body = VertexRequest,
    responses(
        (status = 200, description = "Generated Text", body = VertexResponse),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json!({"error": "Request failed during generation"})),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json!({"error": "Model is overloaded"})),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json!({"error": "Input validation error"})),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json!({"error": "Incomplete generation"})),
    )
)]
#[instrument(
    skip_all,
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
pub(crate) async fn vertex_compatibility(
    Extension(infer): Extension<Infer>,
    Extension(compute_type): Extension<ComputeType>,
    Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let span = tracing::Span::current();
    metrics::counter!("tgi_request_count").increment(1);

    // Check that there's at least one instance
    if req.instances.is_empty() {
        return Err((
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: "Input validation error".to_string(),
                error_type: "Input validation error".to_string(),
            }),
        ));
    }

    // Prepare futures for all instances
    let mut futures = Vec::with_capacity(req.instances.len());

    for instance in req.instances.into_iter() {
        let generate_request = match instance {
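            // Raw-prompt instances map directly onto an internal
            // GenerateRequest; only `max_new_tokens` and `seed` are
            // forwarded from the instance parameters, while sampling and
            // generation details are always enabled.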
            VertexInstance::Generate(instance) => GenerateRequest {
                inputs: instance.inputs.clone(),
                add_special_tokens: true,
                parameters: GenerateParameters {
                    do_sample: true,
                    max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
                    seed: instance.parameters.as_ref().and_then(|p| p.seed),
                    details: true,
                    decoder_input_details: true,
                    ..Default::default()
                },
            },
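            // Chat instances go through the shared chat-to-generate
            // conversion; the tool-usage flag it also returns is not
            // needed here.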
            VertexInstance::Chat(instance) => {
                let (generate_request, _using_tools): (GenerateRequest, bool) =
                    instance.try_into_generate(&infer)?;
                generate_request
            }
        };

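        // Clone the shared handles so the future below can own them.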
        let infer_clone = infer.clone();
        let compute_type_clone = compute_type.clone();
        let span_clone = span.clone();

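        // Defer the call itself so all instances can run concurrently.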
        futures.push(async move {
            generate_internal(
                Extension(infer_clone),
                compute_type_clone,
                Json(generate_request),
                span_clone,
            )
            .await
            .map(|(_, _, Json(generation))| generation.generated_text)
            .map_err(|_| {
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(ErrorResponse {
                        error: "Incomplete generation".into(),
                        error_type: "Incomplete generation".into(),
                    }),
                )
            })
        });
    }

    // Run all futures in parallel and collect the results, failing the whole
    // request if any single generation failed
    let results = futures::future::join_all(futures).await;
    let predictions: Result<Vec<_>, _> = results.into_iter().collect();
    let predictions = predictions?;

    let response = VertexResponse { predictions };
    Ok((HeaderMap::new(), Json(response)).into_response())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Message, MessageBody, MessageContent};

    #[test]
    fn vertex_deserialization() {
        let string = serde_json::json!({
            "instances": [
                {
                    "messages": [{"role": "user", "content": "What's Deep Learning?"}],
                    "max_tokens": 128,
                    "top_p": 0.95,
                    "temperature": 0.7
                }
            ]
        });
        let request: VertexRequest = serde_json::from_value(string).expect("Can deserialize");
        assert_eq!(
            request,
            VertexRequest {
                instances: vec![VertexInstance::Chat(ChatRequest {
                    messages: vec![Message {
                        name: None,
                        role: "user".to_string(),
                        body: MessageBody::Content {
                            content: MessageContent::SingleText(
                                "What's Deep Learning?".to_string()
                            )
                        },
                    }],
                    max_tokens: Some(128),
                    top_p: Some(0.95),
                    temperature: Some(0.7),
                    ..Default::default()
                })]
            }
        );
    }
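
    // A companion check for the raw-prompt path (an addition, not part of the
    // original test suite): an instance without `messages` should fall
    // through the untagged enum to the `Generate` variant.
    #[test]
    fn vertex_generate_instance_deserialization() {
        let string = serde_json::json!({
            "instances": [
                {
                    "inputs": "What is Deep Learning?"
                }
            ]
        });
        let request: VertexRequest = serde_json::from_value(string).expect("Can deserialize");
        assert_eq!(
            request,
            VertexRequest {
                instances: vec![VertexInstance::Generate(GenerateVertexInstance {
                    inputs: "What is Deep Learning?".to_string(),
                    parameters: None,
                })]
            }
        );
    }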
}