// text-generation-inference/router/src/vertex.rs

use crate::infer::Infer;
use crate::server::{generate_internal, ComputeType};
use crate::{
ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest, GrammarType, Message,
StreamOptions, Tool, ToolChoice,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::Json;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use utoipa::ToSchema;
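
/// A single raw text-generation instance inside a Vertex AI prediction request:
/// an `inputs` string plus optional `GenerateParameters`.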
#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct GenerateVertexInstance {
    #[schema(example = "What is Deep Learning?")]
    pub inputs: String,
    #[schema(nullable = true, default = "null", example = "null")]
    pub parameters: Option<GenerateParameters>,
}
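
/// A chat-style instance inside a Vertex AI prediction request: top-level `messages`
/// plus OpenAI-compatible sampling options under `parameters`.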
#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexChat {
    messages: Vec<Message>,
    // Any `messages` nested inside `parameters` is ignored; only the field above is used.
    #[serde(default)]
    parameters: VertexParameters,
}
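
/// OpenAI-style chat options accepted under `parameters` of a Vertex chat instance.
/// Mirrors the fields of `ChatRequest`, minus `messages`.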
#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexParameters {
    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
    /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
    pub model: Option<String>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
    /// decreasing the model's likelihood to repeat the same line verbatim.
    #[serde(default)]
    #[schema(example = "1.0")]
    pub frequency_penalty: Option<f32>,
    /// UNUSED
    /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
    /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
    /// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
    /// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
    /// result in a ban or exclusive selection of the relevant token.
    #[serde(default)]
    pub logit_bias: Option<Vec<f32>>,
    /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
    /// output token returned in the content of message.
    #[serde(default)]
    #[schema(example = "false")]
    pub logprobs: Option<bool>,
    /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
    /// an associated log probability. logprobs must be set to true if this parameter is used.
    #[serde(default)]
    #[schema(example = "5")]
    pub top_logprobs: Option<u32>,
    /// The maximum number of tokens that can be generated in the chat completion.
    #[serde(default)]
    #[schema(example = "32")]
    pub max_tokens: Option<u32>,
    /// UNUSED
    /// How many chat completion choices to generate for each input message. Note that you will be charged based on the
    /// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
    #[serde(default)]
    #[schema(nullable = true, example = "2")]
    pub n: Option<u32>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
    /// increasing the model's likelihood to talk about new topics.
    #[serde(default)]
    #[schema(nullable = true, example = 0.1)]
    pub presence_penalty: Option<f32>,
    /// Up to 4 sequences where the API will stop generating further tokens.
    #[serde(default)]
    #[schema(nullable = true, example = "null")]
    pub stop: Option<Vec<String>>,
    #[serde(default = "bool::default")]
    pub stream: bool,
    #[schema(nullable = true, example = 42)]
    pub seed: Option<u64>,
    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
    /// lower values like 0.2 will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[serde(default)]
    #[schema(nullable = true, example = 1.0)]
    pub temperature: Option<f32>,
    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
    /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    #[serde(default)]
    #[schema(nullable = true, example = 0.95)]
    pub top_p: Option<f32>,
    /// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
    /// functions the model may generate JSON inputs for.
    #[serde(default)]
    #[schema(nullable = true, example = "null")]
    pub tools: Option<Vec<Tool>>,
    /// A prompt to be appended before the tools
    #[serde(default)]
    #[schema(
        nullable = true,
        example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
    )]
    pub tool_prompt: Option<String>,
    /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
    #[serde(default)]
    #[schema(nullable = true, example = "null")]
    pub tool_choice: ToolChoice,
    /// Response format constraints for the generation.
    ///
    /// NOTE: A request can use `response_format` OR `tools` but not both.
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
    pub response_format: Option<GrammarType>,
    /// A guideline to be used in the chat_template
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
    pub guideline: Option<String>,
    /// Options for streaming response. Only set this when you set stream: true.
    #[serde(default)]
    #[schema(nullable = true, example = "null")]
    pub stream_options: Option<StreamOptions>,
}

impl From<VertexChat> for ChatRequest {
    fn from(val: VertexChat) -> Self {
        Self {
            messages: val.messages,
            frequency_penalty: val.parameters.frequency_penalty,
            guideline: val.parameters.guideline,
            logit_bias: val.parameters.logit_bias,
            logprobs: val.parameters.logprobs,
            max_tokens: val.parameters.max_tokens,
            model: val.parameters.model,
            n: val.parameters.n,
            presence_penalty: val.parameters.presence_penalty,
            response_format: val.parameters.response_format,
            seed: val.parameters.seed,
            stop: val.parameters.stop,
            stream_options: val.parameters.stream_options,
            stream: val.parameters.stream,
            temperature: val.parameters.temperature,
            tool_choice: val.parameters.tool_choice,
            tool_prompt: val.parameters.tool_prompt,
            tools: val.parameters.tools,
            top_logprobs: val.parameters.top_logprobs,
            top_p: val.parameters.top_p,
        }
    }
}
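
/// One entry of the Vertex `instances` array. Deserialization is untagged: a payload
/// carrying `inputs` becomes `Generate`, a payload carrying `messages` becomes `Chat`.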
#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
#[serde(untagged)]
pub(crate) enum VertexInstance {
    Generate(GenerateVertexInstance),
    Chat(VertexChat),
}
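
/// Body of a Vertex AI prediction request: a batch of independent instances.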
#[derive(Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexRequest {
    #[serde(rename = "instances")]
    pub instances: Vec<VertexInstance>,
}
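
/// Body of a Vertex AI prediction response: one generated text per input instance.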
#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct VertexResponse {
    pub predictions: Vec<String>,
}

/// Generate tokens from Vertex request
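///
/// Each instance is handled independently: a raw `inputs` payload takes the text-generation
/// path, a `messages` payload takes the chat-template path. An illustrative request body,
/// assembled from the schema examples in this module (not a verbatim Vertex sample):
///
/// ```json
/// {
///   "instances": [
///     { "inputs": "What is Deep Learning?" },
///     {
///       "messages": [{ "role": "user", "content": "What's Deep Learning?" }],
///       "parameters": { "max_tokens": 128, "top_p": 0.95, "temperature": 0.7 }
///     }
///   ]
/// }
/// ```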
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/vertex",
    request_body = VertexRequest,
    responses(
        (status = 200, description = "Generated Text", body = VertexResponse),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json!({"error": "Request failed during generation"})),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json!({"error": "Model is overloaded"})),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json!({"error": "Input validation error"})),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json!({"error": "Incomplete generation"})),
    )
)]
#[instrument(
    skip_all,
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
pub(crate) async fn vertex_compatibility(
    Extension(infer): Extension<Infer>,
    Extension(compute_type): Extension<ComputeType>,
    Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let span = tracing::Span::current();
    metrics::counter!("tgi_request_count").increment(1);

    // Check that there is at least one instance
    if req.instances.is_empty() {
        return Err((
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: "Input validation error".to_string(),
                error_type: "Input validation error".to_string(),
            }),
        ));
    }

    // Prepare futures for all instances
    let mut futures = Vec::with_capacity(req.instances.len());
    for instance in req.instances.into_iter() {
        let generate_request = match instance {
            VertexInstance::Generate(instance) => GenerateRequest {
                inputs: instance.inputs.clone(),
                add_special_tokens: true,
                parameters: GenerateParameters {
                    do_sample: true,
                    max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
                    seed: instance.parameters.as_ref().and_then(|p| p.seed),
                    details: true,
                    decoder_input_details: true,
                    ..Default::default()
                },
            },
            VertexInstance::Chat(instance) => {
                let chat_request: ChatRequest = instance.into();
                let (generate_request, _using_tools): (GenerateRequest, bool) =
                    chat_request.try_into_generate(&infer)?;
                generate_request
            }
        };
        let infer_clone = infer.clone();
        let compute_type_clone = compute_type.clone();
        let span_clone = span.clone();
        futures.push(async move {
            generate_internal(
                Extension(infer_clone),
                compute_type_clone,
                Json(generate_request),
                span_clone,
            )
            .await
            .map(|(_, Json(generation))| generation.generated_text)
            .map_err(|_| {
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(ErrorResponse {
                        error: "Incomplete generation".into(),
                        error_type: "Incomplete generation".into(),
                    }),
                )
            })
        });
    }

    // Execute all futures in parallel and collect the results; any single failure fails the whole request
    let results = futures::future::join_all(futures).await;
    let predictions: Result<Vec<_>, _> = results.into_iter().collect();
    let predictions = predictions?;

    let response = VertexResponse { predictions };
    Ok((HeaderMap::new(), Json(response)).into_response())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Message, MessageContent};

    #[test]
    fn vertex_deserialization() {
        let string = serde_json::json!({
            "messages": [{"role": "user", "content": "What's Deep Learning?"}],
            "parameters": {
                "max_tokens": 128,
                "top_p": 0.95,
                "temperature": 0.7
            }
        });
        let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");

        let string = serde_json::json!({
            "messages": [{"role": "user", "content": "What's Deep Learning?"}],
        });
        let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");

        let string = serde_json::json!({
            "instances": [
                {
                    "messages": [{"role": "user", "content": "What's Deep Learning?"}],
                    "parameters": {
                        "max_tokens": 128,
                        "top_p": 0.95,
                        "temperature": 0.7
                    }
                }
            ]
        });
        let request: VertexRequest = serde_json::from_value(string).expect("Can deserialize");
        assert_eq!(
            request,
            VertexRequest {
                instances: vec![VertexInstance::Chat(VertexChat {
                    messages: vec![Message {
                        role: "user".to_string(),
                        content: MessageContent::SingleText("What's Deep Learning?".to_string()),
                        name: None,
                    }],
                    parameters: VertexParameters {
                        max_tokens: Some(128),
                        top_p: Some(0.95),
                        temperature: Some(0.7),
                        ..Default::default()
                    }
                })]
            }
        );
    }
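
    // Illustrative companion check (a sketch, not part of the original test suite): a payload
    // carrying `inputs` should take the `Generate` arm of the untagged `VertexInstance` enum.
    #[test]
    fn vertex_generate_instance_deserialization() {
        let string = serde_json::json!({
            "instances": [
                {"inputs": "What is Deep Learning?"}
            ]
        });
        let request: VertexRequest = serde_json::from_value(string).expect("Can deserialize");
        match request.instances.as_slice() {
            [VertexInstance::Generate(instance)] => {
                assert_eq!(instance.inputs, "What is Deep Learning?");
                assert!(instance.parameters.is_none());
            }
            other => panic!("expected a single Generate instance, got {other:?}"),
        }
    }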
}