From 9874b15fa87f6ef00420b86b4051f40d001c326c Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 9 Apr 2024 00:37:05 +0000 Subject: [PATCH] fix: adjust tool grammar ownership --- router/src/server.rs | 39 +++++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/router/src/server.rs b/router/src/server.rs index 427ce894..c7c42d77 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -756,19 +756,34 @@ async fn chat_completions( ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { metrics::increment_counter!("tgi_request_count"); - let stream = req.stream; - let max_new_tokens = req.max_tokens.or(Some(100)); - let repetition_penalty = req - .presence_penalty - // rescale repetition_penalty from (-2.0, 2.0) to (0.0, 4.0) - .map(|x| x + 2.0); - let logprobs = req.logprobs.unwrap_or(false); - let seed = req.seed; - let stop = req.stop.unwrap_or_default(); - let tool_prompt = req.tool_prompt.unwrap_or_default(); + let ChatRequest { + frequency_penalty: _, + logit_bias: _, + logprobs, + max_tokens, + messages, + model: _, + n: _, + presence_penalty, + seed, + stop, + stream, + temperature: _, + tools, + tool_choice, + tool_prompt, + top_p: _, + top_logprobs: _, + } = req; + + let repetition_penalty = presence_penalty.map(|x| x + 2.0); + let max_new_tokens = max_tokens.or(Some(100)); + let logprobs = logprobs.unwrap_or(false); + let tool_prompt = tool_prompt.unwrap_or_default(); + let stop = stop.unwrap_or_default(); // extract tool grammar if present - let tool_grammar = match ToolGrammar::apply(req.tools.as_ref(), req.tool_choice.as_ref()) { + let tool_grammar = match ToolGrammar::apply(tools, tool_choice) { Ok(grammar) => grammar, Err(err) => { metrics::increment_counter!("tgi_request_failure", "err" => "validation"); @@ -784,7 +799,7 @@ async fn chat_completions( }; // apply chat template to flatten the request into a single input - let mut inputs = match infer.apply_chat_template(req.messages) { + let mut inputs = match
infer.apply_chat_template(messages) { Ok(inputs) => inputs, Err(err) => { metrics::increment_counter!("tgi_request_failure", "err" => "validation");