fix: adjust tool grammar ownership

This commit is contained in:
drbh 2024-04-09 00:37:05 +00:00
parent bb73acc1a9
commit 9874b15fa8

View File

@ -756,19 +756,34 @@ async fn chat_completions(
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
let stream = req.stream; let ChatRequest {
let max_new_tokens = req.max_tokens.or(Some(100)); frequency_penalty: _,
let repetition_penalty = req logit_bias: _,
.presence_penalty logprobs,
// rescale repetition_penalty from (-2.0, 2.0) to (0.0, 4.0) max_tokens,
.map(|x| x + 2.0); messages,
let logprobs = req.logprobs.unwrap_or(false); model: _,
let seed = req.seed; n: _,
let stop = req.stop.unwrap_or_default(); presence_penalty,
let tool_prompt = req.tool_prompt.unwrap_or_default(); seed,
stop,
stream,
temperature: _,
tools,
tool_choice,
tool_prompt,
top_p: _,
top_logprobs: _,
} = req;
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let max_new_tokens = max_tokens.or(Some(100));
let logprobs = logprobs.unwrap_or(false);
let tool_prompt = tool_prompt.unwrap_or_default();
let stop = stop.unwrap_or_default();
// extract tool grammar if present // extract tool grammar if present
let tool_grammar = match ToolGrammar::apply(req.tools.as_ref(), req.tool_choice.as_ref()) { let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
Ok(grammar) => grammar, Ok(grammar) => grammar,
Err(err) => { Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@ -784,7 +799,7 @@ async fn chat_completions(
}; };
// apply chat template to flatten the request into a single input // apply chat template to flatten the request into a single input
let mut inputs = match infer.apply_chat_template(req.messages) { let mut inputs = match infer.apply_chat_template(messages) {
Ok(inputs) => inputs, Ok(inputs) => inputs,
Err(err) => { Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::increment_counter!("tgi_request_failure", "err" => "validation");