Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
Fix /vertex payload parsing when MESSAGES_API_ENABLED
This commit is contained in:
parent 9263817c71
commit 8ef3da72e1
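
Editor's note: for orientation, the two instance shapes the /vertex route is meant to accept after this change look roughly as follows. This is an illustrative sketch, not part of the commit: the field names come from the structs in the diff below (inputs/parameters for a generate instance, messages/parameters for a chat instance when MESSAGES_API_ENABLED is set), and the surrounding "instances" wrapper is inferred from req.instances in the handler.

use serde_json::json;

fn example_vertex_payloads() {
    // Plain generation instance -> deserializes into VertexInstance::Generate
    let _generate = json!({
        "instances": [
            { "inputs": "What is Deep Learning?", "parameters": { "max_new_tokens": 32 } }
        ]
    });

    // Chat-style instance (messages API) -> deserializes into VertexInstance::Chat
    let _chat = json!({
        "instances": [
            {
                "messages": [{ "role": "user", "content": "What is Deep Learning?" }],
                "parameters": { "max_tokens": 32, "temperature": 1.0 }
            }
        ]
    });
}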
@@ -62,11 +62,132 @@ pub(crate) struct GenerateVertexInstance {
     pub parameters: Option<GenerateParameters>,
 }
 
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+pub(crate) struct ChatRequestParameters {
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
+    /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
+    pub model: Option<String>,
+
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
+    /// decreasing the model's likelihood to repeat the same line verbatim.
+    #[serde(default)]
+    #[schema(example = "1.0")]
+    pub frequency_penalty: Option<f32>,
+
+    /// UNUSED
+    /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
+    /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
+    /// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
+    /// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
+    /// result in a ban or exclusive selection of the relevant token.
+    #[serde(default)]
+    pub logit_bias: Option<Vec<f32>>,
+
+    /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
+    /// output token returned in the content of message.
+    #[serde(default)]
+    #[schema(example = "false")]
+    pub logprobs: Option<bool>,
+
+    /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
+    /// an associated log probability. logprobs must be set to true if this parameter is used.
+    #[serde(default)]
+    #[schema(example = "5")]
+    pub top_logprobs: Option<u32>,
+
+    /// The maximum number of tokens that can be generated in the chat completion.
+    #[serde(default)]
+    #[schema(example = "32")]
+    pub max_tokens: Option<u32>,
+
+    /// UNUSED
+    /// How many chat completion choices to generate for each input message. Note that you will be charged based on the
+    /// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
+    #[serde(default)]
+    #[schema(nullable = true, example = "2")]
+    pub n: Option<u32>,
+
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
+    /// increasing the model's likelihood to talk about new topics
+    #[serde(default)]
+    #[schema(nullable = true, example = 0.1)]
+    pub presence_penalty: Option<f32>,
+
+    /// Up to 4 sequences where the API will stop generating further tokens.
+    #[serde(default)]
+    #[schema(nullable = true, example = "null")]
+    pub stop: Option<Vec<String>>,
+
+    #[serde(default = "bool::default")]
+    pub stream: bool,
+
+    #[schema(nullable = true, example = 42)]
+    pub seed: Option<u64>,
+
+    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
+    /// lower values like 0.2 will make it more focused and deterministic.
+    ///
+    /// We generally recommend altering this or `top_p` but not both.
+    #[serde(default)]
+    #[schema(nullable = true, example = 1.0)]
+    pub temperature: Option<f32>,
+
+    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
+    /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
+    #[serde(default)]
+    #[schema(nullable = true, example = 0.95)]
+    pub top_p: Option<f32>,
+
+    /// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
+    /// functions the model may generate JSON inputs for.
+    #[serde(default)]
+    #[schema(nullable = true, example = "null")]
+    pub tools: Option<Vec<Tool>>,
+
+    /// A prompt to be appended before the tools
+    #[serde(default)]
+    #[schema(
+        nullable = true,
+        example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
+    )]
+    pub tool_prompt: Option<String>,
+
+    /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
+    #[serde(default)]
+    #[schema(nullable = true, example = "null")]
+    pub tool_choice: ToolChoice,
+
+    /// Response format constraints for the generation.
+    ///
+    /// NOTE: A request can use `response_format` OR `tools` but not both.
+    #[serde(default)]
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub response_format: Option<GrammarType>,
+
+    /// A guideline to be used in the chat_template
+    #[serde(default)]
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub guideline: Option<String>,
+
+    /// Options for streaming response. Only set this when you set stream: true.
+    #[serde(default)]
+    #[schema(nullable = true, example = "null")]
+    pub stream_options: Option<StreamOptions>,
+}
+
+#[derive(Clone, Deserialize, ToSchema)]
+pub(crate) struct ChatVertexInstance {
+    #[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")]
+    pub messages: Vec<Message>,
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub parameters: Option<ChatRequestParameters>,
+}
+
 #[derive(Clone, Deserialize, ToSchema)]
 #[serde(untagged)]
 enum VertexInstance {
     Generate(GenerateVertexInstance),
-    Chat(ChatRequest),
+    Chat(ChatVertexInstance),
 }
 
 #[derive(Deserialize, ToSchema)]
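
Editor's note: because VertexInstance is #[serde(untagged)], serde resolves each incoming instance by trying the variants in declaration order and keeping the first one whose required fields are present; a ChatVertexInstance is recognised by its messages field, a GenerateVertexInstance by inputs. A minimal, self-contained sketch of that behavior, using simplified stand-ins for the real structs (not the types from this commit):

use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct Gen {
    inputs: String,
}

#[derive(Deserialize, Debug)]
struct Chat {
    messages: Vec<serde_json::Value>,
}

#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum Instance {
    Generate(Gen),
    Chat(Chat),
}

fn main() {
    let a: Instance = serde_json::from_str(r#"{"inputs": "What is Deep Learning?"}"#).unwrap();
    let b: Instance =
        serde_json::from_str(r#"{"messages": [{"role": "user", "content": "Hi"}]}"#).unwrap();
    // `a` matches the Generate variant (it has `inputs`); `b` matches Chat (it has `messages`).
    println!("{:?} {:?}", a, b);
}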
@@ -8,7 +8,7 @@ use crate::kserve::{
     kserve_model_metadata, kserve_model_metadata_ready,
 };
 use crate::validation::ValidationError;
-use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance};
+use crate::{default_tool_prompt, ChatRequestParameters, ChatTokenizeResponse, VertexInstance};
 use crate::{
     usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
     GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@@ -1436,10 +1436,10 @@ async fn vertex_compatibility(
     // Prepare futures for all instances
     let mut futures = Vec::with_capacity(req.instances.len());
 
-    for instance in req.instances.iter() {
+    for instance in req.instances.into_iter() {
         let generate_request = match instance {
             VertexInstance::Generate(instance) => GenerateRequest {
-                inputs: instance.inputs.clone(),
+                inputs: instance.inputs,
                 add_special_tokens: true,
                 parameters: GenerateParameters {
                     do_sample: true,
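
Editor's note on the iterator change above: switching from iter() to into_iter() means the loop takes ownership of each instance, so fields such as inputs and parameters can be moved into the request instead of cloned. A small illustration of the difference, with placeholder names rather than the real types:

struct Item {
    inputs: String,
}

fn take_inputs(items: Vec<Item>) -> Vec<String> {
    // `iter()` would yield `&Item`, forcing `item.inputs.clone()`;
    // `into_iter()` yields owned `Item`s, so each String is simply moved out.
    items.into_iter().map(|item| item.inputs).collect()
}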
@@ -1451,10 +1451,10 @@ async fn vertex_compatibility(
                 },
             },
             VertexInstance::Chat(instance) => {
-                let ChatRequest {
+                let messages = instance.messages;
+                let ChatRequestParameters {
                     model,
                     max_tokens,
-                    messages,
                     seed,
                     stop,
                     stream,
@@ -1469,7 +1469,7 @@ async fn vertex_compatibility(
                     top_p,
                     top_logprobs,
                     ..
-                } = instance.clone();
+                } = instance.parameters.unwrap();
 
                 let repetition_penalty = presence_penalty.map(|x| x + 2.0);
                 let max_new_tokens = max_tokens.or(Some(100));
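
Editor's note: the last hunk keeps the existing parameter mapping: an OpenAI-style presence_penalty (documented above as a number between -2.0 and 2.0) is shifted by +2.0 to produce the repetition_penalty used for generation, and max_tokens falls back to 100 new tokens when unset. A standalone sketch of just that arithmetic (a hypothetical helper, not the handler itself):

fn map_chat_params(
    presence_penalty: Option<f32>,
    max_tokens: Option<u32>,
) -> (Option<f32>, Option<u32>) {
    // e.g. presence_penalty = Some(0.1) becomes repetition_penalty = Some(2.1)
    let repetition_penalty = presence_penalty.map(|x| x + 2.0);
    // fall back to 100 new tokens when the client did not set max_tokens
    let max_new_tokens = max_tokens.or(Some(100));
    (repetition_penalty, max_new_tokens)
}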