Fix /vertex payload parsing when MESSAGES_API_ENABLED

Author: Alvaro Bartolome 2024-09-23 20:39:20 +02:00
parent 9263817c71
commit 8ef3da72e1
2 changed files with 128 additions and 7 deletions


@@ -62,11 +62,132 @@ pub(crate) struct GenerateVertexInstance
pub parameters: Option<GenerateParameters>,
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct ChatRequestParameters {
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: Option<String>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
/// decreasing the model's likelihood to repeat the same line verbatim.
#[serde(default)]
#[schema(example = "1.0")]
pub frequency_penalty: Option<f32>,
/// UNUSED
/// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
/// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
/// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
/// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
/// result in a ban or exclusive selection of the relevant token.
#[serde(default)]
pub logit_bias: Option<Vec<f32>>,
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
/// output token returned in the content of message.
#[serde(default)]
#[schema(example = "false")]
pub logprobs: Option<bool>,
/// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
/// an associated log probability. logprobs must be set to true if this parameter is used.
#[serde(default)]
#[schema(example = "5")]
pub top_logprobs: Option<u32>,
/// The maximum number of tokens that can be generated in the chat completion.
#[serde(default)]
#[schema(example = "32")]
pub max_tokens: Option<u32>,
/// UNUSED
/// How many chat completion choices to generate for each input message. Note that you will be charged based on the
/// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
#[serde(default)]
#[schema(nullable = true, example = "2")]
pub n: Option<u32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
/// increasing the model's likelihood to talk about new topics
#[serde(default)]
#[schema(nullable = true, example = 0.1)]
pub presence_penalty: Option<f32>,
/// Up to 4 sequences where the API will stop generating further tokens.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub stop: Option<Vec<String>>,
#[serde(default = "bool::default")]
pub stream: bool,
#[schema(nullable = true, example = 42)]
pub seed: Option<u64>,
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
/// lower values like 0.2 will make it more focused and deterministic.
///
/// We generally recommend altering this or `top_p` but not both.
#[serde(default)]
#[schema(nullable = true, example = 1.0)]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
/// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
#[serde(default)]
#[schema(nullable = true, example = 0.95)]
pub top_p: Option<f32>,
/// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
/// functions the model may generate JSON inputs for.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub tools: Option<Vec<Tool>>,
/// A prompt to be appended before the tools
#[serde(default)]
#[schema(
nullable = true,
example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
)]
pub tool_prompt: Option<String>,
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub tool_choice: ToolChoice,
/// Response format constraints for the generation.
///
/// NOTE: A request can use `response_format` OR `tools` but not both.
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub response_format: Option<GrammarType>,
/// A guideline to be used in the chat_template
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub guideline: Option<String>,
/// Options for streaming response. Only set this when you set stream: true.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub stream_options: Option<StreamOptions>,
}
#[derive(Clone, Deserialize, ToSchema)]
pub(crate) struct ChatVertexInstance {
#[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")]
pub messages: Vec<Message>,
#[schema(nullable = true, default = "null", example = "null")]
pub parameters: Option<ChatRequestParameters>,
}
#[derive(Clone, Deserialize, ToSchema)]
#[serde(untagged)]
enum VertexInstance {
Generate(GenerateVertexInstance),
- Chat(ChatRequest),
+ Chat(ChatVertexInstance),
}
#[derive(Deserialize, ToSchema)]
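
As a rough illustration of the two payload shapes the untagged enum now accepts — a minimal sketch using simplified stand-in types rather than the router's real GenerateVertexInstance/ChatVertexInstance, and assuming the serde and serde_json crates are available:

use serde::Deserialize;

// Simplified stand-ins; only the fields relevant to variant selection are kept.
#[derive(Debug, Deserialize)]
struct Message {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct GenerateInstance {
    inputs: String,
}

#[derive(Debug, Deserialize)]
struct ChatInstance {
    messages: Vec<Message>,
    // Optional, mirroring `parameters: Option<ChatRequestParameters>` above.
    parameters: Option<serde_json::Value>,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum Instance {
    // serde tries variants in order: a payload with `inputs` matches Generate,
    // a payload with `messages` falls through to Chat.
    Generate(GenerateInstance),
    Chat(ChatInstance),
}

fn main() {
    let generate: Instance =
        serde_json::from_str(r#"{"inputs": "What is Deep Learning?"}"#).unwrap();
    let chat: Instance = serde_json::from_str(
        r#"{"messages": [{"role": "user", "content": "What is Deep Learning?"}],
            "parameters": {"max_tokens": 32, "temperature": 0.7}}"#,
    )
    .unwrap();
    println!("{generate:?}\n{chat:?}");
}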


@@ -8,7 +8,7 @@ use crate::kserve::{
kserve_model_metadata, kserve_model_metadata_ready,
};
use crate::validation::ValidationError;
- use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance};
+ use crate::{default_tool_prompt, ChatRequestParameters, ChatTokenizeResponse, VertexInstance};
use crate::{
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@@ -1436,10 +1436,10 @@ async fn vertex_compatibility(
// Prepare futures for all instances
let mut futures = Vec::with_capacity(req.instances.len());
- for instance in req.instances.iter() {
+ for instance in req.instances.into_iter() {
let generate_request = match instance {
VertexInstance::Generate(instance) => GenerateRequest {
- inputs: instance.inputs.clone(),
+ inputs: instance.inputs,
add_special_tokens: true,
parameters: GenerateParameters {
do_sample: true,
@@ -1451,10 +1451,10 @@
},
},
VertexInstance::Chat(instance) => {
- let ChatRequest {
+ let messages = instance.messages;
+ let ChatRequestParameters {
model,
max_tokens,
- messages,
seed,
stop,
stream,
@@ -1469,7 +1469,7 @@
top_p,
top_logprobs,
..
- } = instance.clone();
+ } = instance.parameters.unwrap();
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let max_new_tokens = max_tokens.or(Some(100));
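
For reference, a small sketch of the parameter mapping in the Chat branch above, with purely illustrative values (not taken from the commit): presence_penalty is shifted by +2.0 to act as a repetition penalty, and max_tokens falls back to 100 when unset.

fn main() {
    // Illustrative inputs; in the router these come from ChatRequestParameters.
    let presence_penalty: Option<f32> = Some(0.5);
    let max_tokens: Option<u32> = None;

    // Same expressions as in vertex_compatibility above.
    let repetition_penalty = presence_penalty.map(|x| x + 2.0);
    let max_new_tokens = max_tokens.or(Some(100));

    assert_eq!(repetition_penalty, Some(2.5)); // 0.5 + 2.0
    assert_eq!(max_new_tokens, Some(100));     // default when max_tokens is None
}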