use crate::infer::Infer;
use crate::server::{generate_internal, ComputeType};
use crate::{
    ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest, GrammarType, Message,
    StreamOptions, Tool, ToolChoice,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::Json;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use utoipa::ToSchema;

#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct GenerateVertexInstance {
    #[schema(example = "What is Deep Learning?")]
    pub inputs: String,
    #[schema(nullable = true, default = "null", example = "null")]
    pub parameters: Option<GenerateParameters>,
}
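
// Illustrative instance payload for the plain-generate shape above (a sketch, not
// taken from the upstream docs); `max_new_tokens` is one of the optional
// `GenerateParameters` fields:
//
// { "inputs": "What is Deep Learning?", "parameters": { "max_new_tokens": 128 } }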

#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexChat {
    messages: Vec<Message>,
    // Messages is ignored there.
    #[serde(default)]
    parameters: VertexParameters,
}

#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexParameters {
    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
    /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
    pub model: Option<String>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
    /// decreasing the model's likelihood to repeat the same line verbatim.
    #[serde(default)]
    #[schema(example = "1.0")]
    pub frequency_penalty: Option<f32>,

    /// UNUSED
    /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
    /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
    /// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
    /// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
    /// result in a ban or exclusive selection of the relevant token.
    #[serde(default)]
    pub logit_bias: Option<Vec<f32>>,

    /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
    /// output token returned in the content of message.
    #[serde(default)]
    #[schema(example = "false")]
    pub logprobs: Option<bool>,

    /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
    /// an associated log probability. logprobs must be set to true if this parameter is used.
    #[serde(default)]
    #[schema(example = "5")]
    pub top_logprobs: Option<u32>,

    /// The maximum number of tokens that can be generated in the chat completion.
    #[serde(default)]
    #[schema(example = "32")]
    pub max_tokens: Option<u32>,

    /// UNUSED
    /// How many chat completion choices to generate for each input message. Note that you will be charged based on the
    /// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
    #[serde(default)]
    #[schema(nullable = true, example = "2")]
    pub n: Option<u32>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
    /// increasing the model's likelihood to talk about new topics
    #[serde(default)]
    #[schema(nullable = true, example = 0.1)]
    pub presence_penalty: Option<f32>,

    /// Up to 4 sequences where the API will stop generating further tokens.
    #[serde(default)]
    #[schema(nullable = true, example = "null")]
    pub stop: Option<Vec<String>>,

    #[serde(default = "bool::default")]
    pub stream: bool,

    #[schema(nullable = true, example = 42)]
    pub seed: Option<u64>,

    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
    /// lower values like 0.2 will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[serde(default)]
    #[schema(nullable = true, example = 1.0)]
    pub temperature: Option<f32>,

    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
    /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    #[serde(default)]
    #[schema(nullable = true, example = 0.95)]
    pub top_p: Option<f32>,

    /// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
    /// functions the model may generate JSON inputs for.
    #[serde(default)]
    #[schema(nullable = true, example = "null")]
    pub tools: Option<Vec<Tool>>,

    /// A prompt to be appended before the tools
    #[serde(default)]
    #[schema(
        nullable = true,
        example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
    )]
    pub tool_prompt: Option<String>,

    /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
    #[serde(default)]
    #[schema(nullable = true, example = "null")]
    pub tool_choice: ToolChoice,

    /// Response format constraints for the generation.
    ///
    /// NOTE: A request can use `response_format` OR `tools` but not both.
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
    pub response_format: Option<GrammarType>,

    /// A guideline to be used in the chat_template
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
    pub guideline: Option<String>,

    /// Options for streaming response. Only set this when you set stream: true.
    #[serde(default)]
    #[schema(nullable = true, example = "null")]
    pub stream_options: Option<StreamOptions>,
}

impl From<VertexChat> for ChatRequest {
    fn from(val: VertexChat) -> Self {
        Self {
            messages: val.messages,
            frequency_penalty: val.parameters.frequency_penalty,
            guideline: val.parameters.guideline,
            logit_bias: val.parameters.logit_bias,
            logprobs: val.parameters.logprobs,
            max_tokens: val.parameters.max_tokens,
            model: val.parameters.model,
            n: val.parameters.n,
            presence_penalty: val.parameters.presence_penalty,
            response_format: val.parameters.response_format,
            seed: val.parameters.seed,
            stop: val.parameters.stop,
            stream_options: val.parameters.stream_options,
            stream: val.parameters.stream,
            temperature: val.parameters.temperature,
            tool_choice: val.parameters.tool_choice,
            tool_prompt: val.parameters.tool_prompt,
            tools: val.parameters.tools,
            top_logprobs: val.parameters.top_logprobs,
            top_p: val.parameters.top_p,
        }
    }
}

#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
#[serde(untagged)]
pub(crate) enum VertexInstance {
    Generate(GenerateVertexInstance),
    Chat(VertexChat),
}
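
// Note (editorial, not from the upstream sources): with `#[serde(untagged)]`, serde
// tries the variants in declaration order, so an instance object carrying an
// "inputs" string is parsed as `Generate`, while one carrying a "messages" array
// falls through to `Chat`.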

#[derive(Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexRequest {
    #[serde(rename = "instances")]
    pub instances: Vec<VertexInstance>,
}

#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct VertexResponse {
    pub predictions: Vec<String>,
}
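
// Rough request/response sketch for the `/vertex` route below (illustrative only;
// field names follow `VertexRequest`, `VertexChat` and `VertexResponse`):
//
// Request:
//   {"instances": [{"messages": [{"role": "user", "content": "What's Deep Learning?"}],
//                   "parameters": {"max_tokens": 128}}]}
// Response:
//   {"predictions": ["<generated text for the single instance>"]}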

/// Generate tokens from Vertex request
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/vertex",
    request_body = VertexRequest,
    responses(
        (status = 200, description = "Generated Text", body = VertexResponse),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json!({"error": "Request failed during generation"})),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json!({"error": "Model is overloaded"})),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json!({"error": "Input validation error"})),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json!({"error": "Incomplete generation"})),
    )
)]
#[instrument(
    skip_all,
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
pub(crate) async fn vertex_compatibility(
    Extension(infer): Extension<Infer>,
    Extension(compute_type): Extension<ComputeType>,
    Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let span = tracing::Span::current();
    metrics::counter!("tgi_request_count").increment(1);

    // Check that there's at least one instance
    if req.instances.is_empty() {
        return Err((
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: "Input validation error".to_string(),
                error_type: "Input validation error".to_string(),
            }),
        ));
    }

    // Prepare futures for all instances
    let mut futures = Vec::with_capacity(req.instances.len());

    for instance in req.instances.into_iter() {
        let generate_request = match instance {
            VertexInstance::Generate(instance) => GenerateRequest {
                inputs: instance.inputs.clone(),
                add_special_tokens: true,
                parameters: GenerateParameters {
                    do_sample: true,
                    max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
                    seed: instance.parameters.as_ref().and_then(|p| p.seed),
                    details: true,
                    decoder_input_details: true,
                    ..Default::default()
                },
            },
            VertexInstance::Chat(instance) => {
                let chat_request: ChatRequest = instance.into();
                let (generate_request, _using_tools): (GenerateRequest, bool) =
                    chat_request.try_into_generate(&infer)?;
                generate_request
            }
        };

        let infer_clone = infer.clone();
        let compute_type_clone = compute_type.clone();
        let span_clone = span.clone();

        futures.push(async move {
            generate_internal(
                Extension(infer_clone),
                compute_type_clone,
                Json(generate_request),
                span_clone,
            )
            .await
            .map(|(_, Json(generation))| generation.generated_text)
            .map_err(|_| {
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(ErrorResponse {
                        error: "Incomplete generation".into(),
                        error_type: "Incomplete generation".into(),
                    }),
                )
            })
        });
    }

    // Execute all futures in parallel, collect results, returning early if any error occurs
    let results = futures::future::join_all(futures).await;
    let predictions: Result<Vec<_>, _> = results.into_iter().collect();
    let predictions = predictions?;

    let response = VertexResponse { predictions };
    Ok((HeaderMap::new(), Json(response)).into_response())
}
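
// Minimal usage sketch (illustrative; host and port are placeholders, and the route
// must be exposed by the running router):
//
//   curl -X POST http://<router-host>:<port>/vertex \
//     -H 'Content-Type: application/json' \
//     -d '{"instances": [{"inputs": "What is Deep Learning?"}]}'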

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Message, MessageContent};

    #[test]
    fn vertex_deserialization() {
        let string = serde_json::json!({
            "messages": [{"role": "user", "content": "What's Deep Learning?"}],
            "parameters": {
                "max_tokens": 128,
                "top_p": 0.95,
                "temperature": 0.7
            }
        });
        let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");

        let string = serde_json::json!({
            "messages": [{"role": "user", "content": "What's Deep Learning?"}],
        });
        let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");

        let string = serde_json::json!({
            "instances": [
                {
                    "messages": [{"role": "user", "content": "What's Deep Learning?"}],
                    "parameters": {
                        "max_tokens": 128,
                        "top_p": 0.95,
                        "temperature": 0.7
                    }
                }
            ]
        });
        let request: VertexRequest = serde_json::from_value(string).expect("Can deserialize");
        assert_eq!(
            request,
            VertexRequest {
                instances: vec![VertexInstance::Chat(VertexChat {
                    messages: vec![Message {
                        role: "user".to_string(),
                        content: MessageContent::SingleText("What's Deep Learning?".to_string()),
                        name: None,
                    }],
                    parameters: VertexParameters {
                        max_tokens: Some(128),
                        top_p: Some(0.95),
                        temperature: Some(0.7),
                        ..Default::default()
                    }
                })]
            }
        );
    }
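
    // A minimal additional sketch (not part of the upstream test suite): the same
    // `instances` envelope with a plain `inputs` string should select the `Generate`
    // variant of the untagged `VertexInstance` enum. Assumes `GenerateParameters`
    // accepts a partial JSON object, as it does for the `/generate` route.
    #[test]
    fn vertex_generate_instance_deserialization() {
        let string = serde_json::json!({
            "instances": [
                {
                    "inputs": "What is Deep Learning?",
                    "parameters": { "max_new_tokens": 128 }
                }
            ]
        });

        let request: VertexRequest = serde_json::from_value(string).expect("Can deserialize");
        match &request.instances[0] {
            VertexInstance::Generate(instance) => {
                assert_eq!(instance.inputs, "What is Deep Learning?");
                assert_eq!(
                    instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
                    Some(128)
                );
            }
            _ => panic!("expected the Generate variant"),
        }
    }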
}