diff --git a/router/src/lib.rs b/router/src/lib.rs
index 3905a1ec..c0f86b11 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -62,7 +62,7 @@ pub(crate) struct GenerateVertexInstance {
     pub parameters: Option<GenerateParameters>,
 }
 
-#[derive(Clone, Deserialize, ToSchema, Serialize)]
+#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
 pub(crate) struct ChatRequestParameters {
     #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
diff --git a/router/src/server.rs b/router/src/server.rs
index 4647b150..20329e6d 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1452,43 +1452,27 @@ async fn vertex_compatibility(
         },
         VertexInstance::Chat(instance) => {
             let messages = instance.messages;
-            let ChatRequestParameters {
-                model,
-                max_tokens,
-                seed,
-                stop,
-                stream,
-                tools,
-                tool_choice,
-                tool_prompt,
-                temperature,
-                response_format,
-                guideline,
-                presence_penalty,
-                frequency_penalty,
-                top_p,
-                top_logprobs,
-                ..
-            } = instance.parameters.unwrap();
+            let chat_request: ChatRequestParameters = instance.parameters.unwrap_or_default();
 
-            let repetition_penalty = presence_penalty.map(|x| x + 2.0);
-            let max_new_tokens = max_tokens.or(Some(100));
-            let tool_prompt = tool_prompt
+            let repetition_penalty = chat_request.presence_penalty.map(|x| x + 2.0);
+            let max_new_tokens = chat_request.max_tokens.or(Some(100));
+            let tool_prompt = chat_request
+                .tool_prompt
                 .filter(|s| !s.is_empty())
                 .unwrap_or_else(default_tool_prompt);
-            let stop = stop.unwrap_or_default();
+            let stop = chat_request.stop.unwrap_or_default();
             // enable greedy only when temperature is 0
-            let (do_sample, temperature) = match temperature {
+            let (do_sample, temperature) = match chat_request.temperature {
                 Some(temperature) if temperature == 0.0 => (false, None),
                 other => (true, other),
             };
             let (inputs, grammar, _using_tools) = match prepare_chat_input(
                 &infer,
-                response_format,
-                tools,
-                tool_choice,
+                chat_request.response_format,
+                chat_request.tools,
+                chat_request.tool_choice,
                 &tool_prompt,
-                guideline,
+                chat_request.guideline,
                 messages,
             ) {
                 Ok(result) => result,
@@ -1510,9 +1494,9 @@ async fn vertex_compatibility(
                     best_of: None,
                     temperature,
                     repetition_penalty,
-                    frequency_penalty,
+                    frequency_penalty: chat_request.frequency_penalty,
                     top_k: None,
-                    top_p,
+                    top_p: chat_request.top_p,
                     typical_p: None,
                     do_sample,
                     max_new_tokens,
@@ -1521,11 +1505,11 @@ async fn vertex_compatibility(
                     truncate: None,
                     watermark: false,
                     details: true,
-                    decoder_input_details: !stream,
-                    seed,
-                    top_n_tokens: top_logprobs,
+                    decoder_input_details: !chat_request.stream,
+                    seed: chat_request.seed,
+                    top_n_tokens: chat_request.top_logprobs,
                     grammar,
-                    adapter_id: model.filter(|m| *m != "tgi").map(String::from),
+                    adapter_id: chat_request.model.filter(|m| *m != "tgi").map(String::from),
                 },
             }
         }
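
Context for the change above: deriving `Default` on `ChatRequestParameters` lets the Vertex handler fall back to `instance.parameters.unwrap_or_default()` instead of panicking on `unwrap()` when a chat instance arrives without a `parameters` object. Below is a minimal sketch of that behavior; `ChatParams` and `ChatInstance` are hypothetical trimmed-down stand-ins, not the real TGI structs.

```rust
// Minimal sketch, assuming serde + serde_json as dependencies.
// `ChatParams` / `ChatInstance` are illustrative stand-ins only.
use serde::Deserialize;

// Hypothetical trimmed-down analogue of `ChatRequestParameters`.
#[derive(Clone, Debug, Default, Deserialize)]
#[serde(default)]
struct ChatParams {
    max_tokens: Option<u32>,
    temperature: Option<f32>,
    stream: bool,
}

// Hypothetical analogue of the Vertex chat instance.
#[derive(Deserialize)]
struct ChatInstance {
    messages: Vec<String>,
    parameters: Option<ChatParams>,
}

fn main() {
    // Payload with no "parameters" key: the field deserializes to `None`.
    let instance: ChatInstance =
        serde_json::from_str(r#"{ "messages": ["Hello"] }"#).unwrap();

    // Before the patch, `instance.parameters.unwrap()` would panic here.
    // With `Default` derived, every field falls back to its default value.
    let params = instance.parameters.unwrap_or_default();
    assert_eq!(params.max_tokens, None);
    assert_eq!(params.temperature, None);
    assert!(!params.stream);

    // Downstream fallbacks such as `max_tokens.or(Some(100))` then apply.
    let max_new_tokens = params.max_tokens.or(Some(100));
    assert_eq!(max_new_tokens, Some(100));
}
```

The handler's existing fallbacks (`or(Some(100))` for `max_new_tokens`, `unwrap_or_default()` for `stop`, the default tool prompt) are unchanged, so a missing `parameters` object now yields the same request as an empty one.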