diff --git a/router/src/lib.rs b/router/src/lib.rs
index ad8924df..3905a1ec 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -62,11 +62,132 @@ pub(crate) struct GenerateVertexInstance {
     pub parameters: Option<GenerateParameters>,
 }
 
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+pub(crate) struct ChatRequestParameters {
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
+    /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
+    pub model: Option<String>,
+
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
+    /// decreasing the model's likelihood to repeat the same line verbatim.
+    #[serde(default)]
+    #[schema(example = "1.0")]
+    pub frequency_penalty: Option<f32>,
+
+    /// UNUSED
+    /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
+    /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
+    /// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
+    /// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
+    /// result in a ban or exclusive selection of the relevant token.
+    #[serde(default)]
+    pub logit_bias: Option<Vec<f32>>,
+
+    /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
+    /// output token returned in the content of message.
+    #[serde(default)]
+    #[schema(example = "false")]
+    pub logprobs: Option<bool>,
+
+    /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
+    /// an associated log probability. logprobs must be set to true if this parameter is used.
+    #[serde(default)]
+    #[schema(example = "5")]
+    pub top_logprobs: Option<u32>,
+
+    /// The maximum number of tokens that can be generated in the chat completion.
+    #[serde(default)]
+    #[schema(example = "32")]
+    pub max_tokens: Option<u32>,
+
+    /// UNUSED
+    /// How many chat completion choices to generate for each input message. Note that you will be charged based on the
+    /// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
+    #[serde(default)]
+    #[schema(nullable = true, example = "2")]
+    pub n: Option<u32>,
+
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
+    /// increasing the model's likelihood to talk about new topics
+    #[serde(default)]
+    #[schema(nullable = true, example = 0.1)]
+    pub presence_penalty: Option<f32>,
+
+    /// Up to 4 sequences where the API will stop generating further tokens.
+    #[serde(default)]
+    #[schema(nullable = true, example = "null")]
+    pub stop: Option<Vec<String>>,
+
+    #[serde(default = "bool::default")]
+    pub stream: bool,
+
+    #[schema(nullable = true, example = 42)]
+    pub seed: Option<u64>,
+
+    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
+    /// lower values like 0.2 will make it more focused and deterministic.
+    ///
+    /// We generally recommend altering this or `top_p` but not both.
+    #[serde(default)]
+    #[schema(nullable = true, example = 1.0)]
+    pub temperature: Option<f32>,
+
+    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
+    /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
+    #[serde(default)]
+    #[schema(nullable = true, example = 0.95)]
+    pub top_p: Option<f32>,
+
+    /// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
+    /// functions the model may generate JSON inputs for.
+    #[serde(default)]
+    #[schema(nullable = true, example = "null")]
+    pub tools: Option<Vec<Tool>>,
+
+    /// A prompt to be appended before the tools
+    #[serde(default)]
+    #[schema(
+        nullable = true,
+        example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
+    )]
+    pub tool_prompt: Option<String>,
+
+    /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
+    #[serde(default)]
+    #[schema(nullable = true, example = "null")]
+    pub tool_choice: ToolChoice,
+
+    /// Response format constraints for the generation.
+    ///
+    /// NOTE: A request can use `response_format` OR `tools` but not both.
+    #[serde(default)]
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub response_format: Option<GrammarType>,
+
+    /// A guideline to be used in the chat_template
+    #[serde(default)]
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub guideline: Option<String>,
+
+    /// Options for streaming response. Only set this when you set stream: true.
+    #[serde(default)]
+    #[schema(nullable = true, example = "null")]
+    pub stream_options: Option<StreamOptions>,
+}
+
+#[derive(Clone, Deserialize, ToSchema)]
+pub(crate) struct ChatVertexInstance {
+    #[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")]
+    pub messages: Vec<Message>,
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub parameters: Option<ChatRequestParameters>,
+}
+
 #[derive(Clone, Deserialize, ToSchema)]
 #[serde(untagged)]
 enum VertexInstance {
     Generate(GenerateVertexInstance),
-    Chat(ChatRequest),
+    Chat(ChatVertexInstance),
 }
 
 #[derive(Deserialize, ToSchema)]
diff --git a/router/src/server.rs b/router/src/server.rs
index 32c86e0f..4647b150 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -8,7 +8,7 @@ use crate::kserve::{
     kserve_model_metadata, kserve_model_metadata_ready,
 };
 use crate::validation::ValidationError;
-use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance};
+use crate::{default_tool_prompt, ChatRequestParameters, ChatTokenizeResponse, VertexInstance};
 use crate::{
     usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
     GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@@ -1436,10 +1436,10 @@ async fn vertex_compatibility(
     // Prepare futures for all instances
     let mut futures = Vec::with_capacity(req.instances.len());
 
-    for instance in req.instances.iter() {
+    for instance in req.instances.into_iter() {
         let generate_request = match instance {
             VertexInstance::Generate(instance) => GenerateRequest {
-                inputs: instance.inputs.clone(),
+                inputs: instance.inputs,
                 add_special_tokens: true,
                 parameters: GenerateParameters {
                     do_sample: true,
@@ -1451,10 +1451,10 @@ async fn vertex_compatibility(
                 },
             },
             VertexInstance::Chat(instance) => {
-                let ChatRequest {
+                let messages = instance.messages;
+                let ChatRequestParameters {
                     model,
                     max_tokens,
-                    messages,
                     seed,
                     stop,
                     stream,
@@ -1469,7 +1469,7 @@ async fn vertex_compatibility(
                     top_p,
                     top_logprobs,
                     ..
-                } = instance.clone();
+                } = instance.parameters.unwrap();
 
                 let repetition_penalty = presence_penalty.map(|x| x + 2.0);
                 let max_new_tokens = max_tokens.or(Some(100));
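
Illustrative note (not part of the diff): `VertexInstance` remains a `#[serde(untagged)]` enum, so a Vertex instance carrying an `inputs` string still deserializes as `Generate`, while one carrying `messages` plus an optional `parameters` object now deserializes as the new `ChatVertexInstance`. Below is a minimal, self-contained sketch of that dispatch using simplified stand-in types (`Instance`, `GenerateInstance`, `ChatInstance`, `ChatParams`, `Msg` are illustrative names, not the router's own), assuming `serde` with the `derive` feature and `serde_json` as dependencies:

```rust
// Minimal sketch of untagged variant dispatch, mirroring the router's
// VertexInstance shape with simplified stand-in types.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct GenerateInstance {
    inputs: String,
}

#[derive(Debug, Deserialize)]
struct Msg {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct ChatParams {
    max_tokens: Option<u32>,
    temperature: Option<f32>,
}

#[derive(Debug, Deserialize)]
struct ChatInstance {
    messages: Vec<Msg>,
    #[serde(default)]
    parameters: Option<ChatParams>,
}

// Untagged: serde tries the variants in declaration order, so a payload with
// an `inputs` string matches Generate, one with `messages` matches Chat.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum Instance {
    Generate(GenerateInstance),
    Chat(ChatInstance),
}

fn main() -> Result<(), serde_json::Error> {
    let generate = r#"{ "inputs": "What is Deep Learning?" }"#;
    let chat = r#"{
        "messages": [{ "role": "user", "content": "What is Deep Learning?" }],
        "parameters": { "max_tokens": 128, "temperature": 0.7 }
    }"#;

    // Resolves to the Generate variant.
    println!("{:?}", serde_json::from_str::<Instance>(generate)?);
    // Resolves to the Chat variant, with `parameters` as Some(...).
    println!("{:?}", serde_json::from_str::<Instance>(chat)?);
    Ok(())
}
```

Because the variants are tried in order, a payload that happened to contain both `inputs` and `messages` would resolve to `Generate`; chat-style instances are only picked when no `inputs` field is present.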