From bb73acc1a978b7863b8c2fb503f3f56c228f36de Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 4 Apr 2024 01:10:06 +0000 Subject: [PATCH] feat: update default prompt and other small refactors --- router/src/lib.rs | 2 +- router/src/server.rs | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 020117c0..56bb0ba4 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -669,7 +669,7 @@ pub(crate) struct ChatRequest { #[serde(default = "default_tool_prompt")] #[schema( nullable = true, - example = "\"Based on the conversation, please choose the most appropriate tool to use: \"" + example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"" )] pub tool_prompt: Option<String>, diff --git a/router/src/server.rs b/router/src/server.rs index 17fcedd6..427ce894 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -767,9 +767,9 @@ async fn chat_completions( let stop = req.stop.unwrap_or_default(); let tool_prompt = req.tool_prompt.unwrap_or_default(); - // apply chat template to flatten the request into a single input - let mut inputs = match infer.apply_chat_template(req.messages) { - Ok(inputs) => inputs, + // extract tool grammar if present + let tool_grammar = match ToolGrammar::apply(req.tools.as_ref(), req.tool_choice.as_ref()) { + Ok(grammar) => grammar, Err(err) => { metrics::increment_counter!("tgi_request_failure", "err" => "validation"); tracing::error!("{err}"); @@ -783,8 +783,9 @@ async fn chat_completions( } }; - let tool_grammar = match ToolGrammar::apply(req.tools.as_ref(), req.tool_choice.as_ref()) { - Ok(grammar) => grammar, + // apply chat template to flatten the request into a single input + let mut inputs = match infer.apply_chat_template(req.messages) { + Ok(inputs) => inputs, 
Err(err) => { metrics::increment_counter!("tgi_request_failure", "err" => "validation"); tracing::error!("{err}"); @@ -809,7 +810,7 @@ async fn chat_completions( ) })?; inputs = format!("{inputs}{tool_prompt}{tools_str}"); - Some(GrammarType::Json(serde_json::to_value(tools).unwrap())) + Some(GrammarType::Json(serde_json::json!(tools))) } else { None };