diff --git a/Cargo.lock b/Cargo.lock
index d298c379..aa5cb642 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2174,6 +2174,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d"
 dependencies = [
  "serde",
+ "serde_json",
 ]
 
 [[package]]
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 7773e212..45acab8e 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -46,7 +46,7 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 init-tracing-opentelemetry = { version = "0.14.1", features = [
     "opentelemetry-otlp",
 ] }
-minijinja = { version = "2.0.2" }
+minijinja = { version = "2.0.2", features = ["json"] }
 minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
 futures-util = "0.3.30"
 regex = "1.10.3"
diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs
index a8537818..63bc8c1b 100644
--- a/router/src/infer/chat_template.rs
+++ b/router/src/infer/chat_template.rs
@@ -1,9 +1,7 @@
 use std::collections::HashSet;
 
 use crate::infer::InferError;
-use crate::{
-    ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
-};
+use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
 use minijinja::{Environment, ErrorKind, Template};
 use minijinja_contrib::pycompat;
 
@@ -32,6 +30,7 @@ impl ChatTemplate {
         env.set_unknown_method_callback(pycompat::unknown_method_callback);
         let template_str = template.into_boxed_str();
         env.add_function("raise_exception", raise_exception);
+        tracing::debug!("Loading template: {:#?}", template_str);
 
         // leaking env and template_str as read-only, static resources for performance.
         let template = Box::leak(env)
@@ -42,6 +41,7 @@
         let variables = template.undeclared_variables(true);
         // check if the `tools` variable is used in the template
         let use_default_tool_template = !variables.contains("tools");
+        tracing::debug!("Use default tool template: {}", use_default_tool_template);
 
         Self {
             template,
@@ -56,36 +56,43 @@
         &self,
         guideline: Option<&str>,
         mut messages: Vec<Message>,
-        grammar_with_prompt: Option<(GrammarType, String)>,
+        tools_and_prompt: Option<(Option<Vec<Tool>>, String)>,
     ) -> Result<String, InferError> {
-        if self.use_default_tool_template {
-            if let Some(last_message) = messages.last_mut() {
-                if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
-                    last_message.content.push(MessageChunk::Text {
-                        text: format!("\n---\n{}\n{}", tool_prompt, tools),
-                    });
-                }
-            }
-        }
-
-        let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
-
         // check if guideline is expected but not provided
         if self.variables.contains("guideline") && guideline.is_none() {
            return Err(InferError::MissingTemplateVariable("guideline".to_string()));
         }
 
-        self.template
+        let (tools, tool_prompt) = tools_and_prompt.unwrap_or_default();
+
+        if tools.is_some() {
+            // check if the `tools` variable is used in the template
+            // if not, we need to append the tools to the last message
+            let text = if self.use_default_tool_template {
+                format!("\n---\n{:?}\n{}", tools, tool_prompt)
+            } else {
+                // if the `tools` variable is used in the template, we just append the tool_prompt
+                format!("\n---\n{}", tool_prompt)
+            };
+            if let Some(last_message) = messages.last_mut() {
+                last_message.content.push(MessageChunk::Text { text });
+            }
+        }
+
+        let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
+
+        return self
+            .template
             .render(ChatTemplateInputs {
                 guideline,
                 messages,
                 bos_token: self.bos_token.as_deref(),
                 eos_token: self.eos_token.as_deref(),
                 add_generation_prompt: true,
-                tools: None,
+                tools: tools,
                 tools_prompt: None,
             })
-            .map_err(InferError::TemplateError)
+            .map_err(InferError::TemplateError);
     }
 }
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index c9354d9a..5329dc72 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -3,7 +3,7 @@ mod chat_template;
 pub mod tool_grammar;
 
 use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
-use crate::GrammarType;
+use crate::Tool;
 use crate::{
     ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
     Message, PrefillToken, Token,
@@ -140,12 +140,12 @@
         &self,
         guideline: Option<String>,
         messages: Vec<Message>,
-        grammar_with_prompt: Option<(GrammarType, String)>,
+        tools_and_prompt: Option<(Option<Vec<Tool>>, String)>,
     ) -> Result<String, InferError> {
         self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(guideline.as_deref(), messages, grammar_with_prompt)
+            .apply(guideline.as_deref(), messages, tools_and_prompt)
             .map_err(|e| {
                 metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
                 tracing::error!("{e}");
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 1b2ff153..48aaf682 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -843,7 +843,7 @@ pub(crate) struct ChatRequest {
     #[serde(default = "default_tool_prompt")]
     #[schema(
         nullable = true,
-        example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\""
+        example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
     )]
     pub tool_prompt: Option<String>,
 
@@ -867,7 +867,7 @@
 
 fn default_tool_prompt() -> Option<String> {
     Some(
-        "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
+        "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string(),
     )
 }
 
@@ -968,7 +968,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
     bos_token: Option<&'a str>,
     eos_token: Option<&'a str>,
     add_generation_prompt: bool,
-    tools: Option<&'a str>,
+    tools: Option<Vec<Tool>>,
     tools_prompt: Option<&'a str>,
     guideline: Option<&'a str>,
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index 8ec7a871..27f0287a 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -2562,13 +2562,11 @@ fn prepare_chat_input(
     }
 
     // if tools are set, apply the tool grammar and then the chat template
-    let tool_grammar: Option<Tools> = ToolGrammar::apply(tools, tool_choice)?;
+    let tool_grammar: Option<Tools> = ToolGrammar::apply(tools.clone(), tool_choice)?;
     let grammar = tool_grammar
         .as_ref()
         .map(|t| GrammarType::Json(serde_json::json!(t)));
-    let tools_grammar_prompt = tool_grammar
-        .as_ref()
-        .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into()));
-    let inputs = infer.apply_chat_template(guideline, messages, tools_grammar_prompt)?;
+    let tools_and_prompt: (Option<Vec<Tool>>, String) = (tools, tool_prompt.into());
+    let inputs = infer.apply_chat_template(guideline, messages, Some(tools_and_prompt))?;
     Ok((inputs, grammar, tool_grammar))
 }
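Reviewer note on the `minijinja` feature bump: templates that consume the new `tools` variable typically serialize it with minijinja's `tojson` filter, which is gated behind the crate's `json` feature, hence the `router/Cargo.toml` change. The sketch below is illustrative only and is not part of the patch: the `Tool` struct and template string are hypothetical stand-ins, not the router's actual types.

```rust
// Minimal sketch, assuming a hypothetical `Tool` stand-in; the router's
// real `Tool` type in router/src/lib.rs has a different shape.
//
// [dependencies]
// minijinja = { version = "2.0.2", features = ["json"] }
// serde = { version = "1", features = ["derive"] }

use minijinja::{context, Environment};
use serde::Serialize;

#[derive(Serialize)]
struct Tool {
    name: String,
    description: String,
}

fn main() {
    let mut env = Environment::new();
    // A template that references `tools`, so the router's
    // `template.undeclared_variables(true)` probe would report it as used
    // and skip the default append-to-last-message fallback.
    env.add_template(
        "chat",
        "{% for m in messages %}{{ m }}\n{% endfor %}\
         {% if tools %}Available tools: {{ tools | tojson }}{% endif %}",
    )
    .unwrap();

    let tools = vec![Tool {
        name: "get_weather".to_string(),
        description: "Look up the current weather".to_string(),
    }];

    // `tojson` is only compiled in when minijinja's `json` feature is
    // enabled, which is what the Cargo.toml change turns on.
    let rendered = env
        .get_template("chat")
        .unwrap()
        .render(context! { messages => vec!["Hello"], tools => tools })
        .unwrap();
    println!("{rendered}");
}
```

Conversely, when a template never references `tools`, `undeclared_variables(true)` omits it, `use_default_tool_template` becomes true, and the patched `apply` falls back to appending the tool prompt (plus a `{:?}`-formatted dump of the tools) to the last message instead of passing them through `ChatTemplateInputs`.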