fix[router]: Fix tools not passed in chat template

Signed-off-by: GitHub <noreply@github.com>
Author: Simone Rossi
Date:   2024-08-22 15:48:37 +00:00 (committed by GitHub)
parent 358ceb67dd
commit 9a3e838079
6 changed files with 37 additions and 31 deletions

Cargo.lock (generated)

@@ -2174,6 +2174,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d"
 dependencies = [
  "serde",
+ "serde_json",
 ]

 [[package]]

router/Cargo.toml

@@ -46,7 +46,7 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 init-tracing-opentelemetry = { version = "0.14.1", features = [
     "opentelemetry-otlp",
 ] }
-minijinja = { version = "2.0.2" }
+minijinja = { version = "2.0.2", features = ["json"] }
 minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
 futures-util = "0.3.30"
 regex = "1.10.3"
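For context (not part of the diff): minijinja's json feature is what provides the tojson filter that chat templates typically use to serialize the tool list. A minimal sketch, assuming minijinja 2.x built with features = ["json"]:

use minijinja::{context, Environment};

fn main() {
    let mut env = Environment::new();
    // `tojson` only exists when minijinja is compiled with the `json` feature.
    env.add_template("chat", "Available tools: {{ tools | tojson }}")
        .unwrap();
    let rendered = env
        .get_template("chat")
        .unwrap()
        .render(context! { tools => vec!["get_weather", "get_current_time"] })
        .unwrap();
    // Prints something like: Available tools: ["get_weather","get_current_time"]
    println!("{rendered}");
}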

router/src/infer/chat_template.rs

@@ -1,9 +1,7 @@
 use std::collections::HashSet;

 use crate::infer::InferError;
-use crate::{
-    ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
-};
+use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
 use minijinja::{Environment, ErrorKind, Template};
 use minijinja_contrib::pycompat;
@@ -32,6 +30,7 @@ impl ChatTemplate {
         env.set_unknown_method_callback(pycompat::unknown_method_callback);
         let template_str = template.into_boxed_str();
         env.add_function("raise_exception", raise_exception);
+        tracing::debug!("Loading template: {:#?}", template_str);

         // leaking env and template_str as read-only, static resources for performance.
         let template = Box::leak(env)
@@ -42,6 +41,7 @@ impl ChatTemplate {
         let variables = template.undeclared_variables(true);
         // check if the `tools` variable is used in the template
         let use_default_tool_template = !variables.contains("tools");
+        tracing::debug!("Use default tool template: {}", use_default_tool_template);

         Self {
             template,
@@ -56,36 +56,43 @@ impl ChatTemplate {
         &self,
         guideline: Option<&str>,
         mut messages: Vec<Message>,
-        grammar_with_prompt: Option<(GrammarType, String)>,
+        tools_and_prompt: Option<(Option<Vec<Tool>>, String)>,
     ) -> Result<String, InferError> {
-        if self.use_default_tool_template {
-            if let Some(last_message) = messages.last_mut() {
-                if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
-                    last_message.content.push(MessageChunk::Text {
-                        text: format!("\n---\n{}\n{}", tool_prompt, tools),
-                    });
-                }
-            }
-        }
-
-        let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
-
         // check if guideline is expected but not provided
         if self.variables.contains("guideline") && guideline.is_none() {
             return Err(InferError::MissingTemplateVariable("guideline".to_string()));
         }

-        self.template
+        let (tools, tool_prompt) = tools_and_prompt.unwrap_or_default();
+
+        if tools.is_some() {
+            // check if the `tools` variable is used in the template
+            // if not, we need to append the tools to the last message
+            let text = if self.use_default_tool_template {
+                format!("\n---\n{:?}\n{}", tools, tool_prompt)
+            } else {
+                // if the `tools` variable is used in the template, we just append the tool_prompt
+                format!("\n---\n{}", tool_prompt)
+            };
+
+            if let Some(last_message) = messages.last_mut() {
+                last_message.content.push(MessageChunk::Text { text });
+            }
+        }
+
+        let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
+
+        return self
+            .template
             .render(ChatTemplateInputs {
                 guideline,
                 messages,
                 bos_token: self.bos_token.as_deref(),
                 eos_token: self.eos_token.as_deref(),
                 add_generation_prompt: true,
-                tools: None,
+                tools: tools,
                 tools_prompt: None,
             })
-            .map_err(InferError::TemplateError)
+            .map_err(InferError::TemplateError);
     }
 }
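A side note on the use_default_tool_template flag that the new debug line logs: it relies on minijinja's template introspection. A small self-contained illustration (plain minijinja, no TGI types):

use minijinja::Environment;

fn main() {
    let mut env = Environment::new();
    env.add_template("with_tools", "{{ messages }} {{ tools }}").unwrap();
    env.add_template("without_tools", "{{ messages }}").unwrap();

    for name in ["with_tools", "without_tools"] {
        let template = env.get_template(name).unwrap();
        // `undeclared_variables(true)` walks the template and returns every
        // variable it reads; this mirrors `!variables.contains("tools")` above.
        let variables = template.undeclared_variables(true);
        println!("{name}: use default tool template = {}", !variables.contains("tools"));
    }
}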

router/src/infer/mod.rs

@@ -3,7 +3,7 @@ mod chat_template;
 pub mod tool_grammar;

 use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
-use crate::GrammarType;
+use crate::Tool;
 use crate::{
     ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
     Message, PrefillToken, Token,
@@ -140,12 +140,12 @@ impl Infer {
         &self,
         guideline: Option<String>,
         messages: Vec<Message>,
-        grammar_with_prompt: Option<(GrammarType, String)>,
+        tools_and_prompt: Option<(Option<Vec<Tool>>, String)>,
     ) -> Result<String, InferError> {
         self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(guideline.as_deref(), messages, grammar_with_prompt)
+            .apply(guideline.as_deref(), messages, tools_and_prompt)
             .map_err(|e| {
                 metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
                 tracing::error!("{e}");

router/src/lib.rs

@@ -843,7 +843,7 @@ pub(crate) struct ChatRequest {
     #[serde(default = "default_tool_prompt")]
     #[schema(
         nullable = true,
-        example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\""
+        example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
     )]
     pub tool_prompt: Option<String>,
@@ -867,7 +867,7 @@ pub(crate) struct ChatRequest {
 fn default_tool_prompt() -> Option<String> {
     Some(
-        "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
+        "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string(),
     )
 }
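The replacement prompt asks the model to answer with a single {name, parameters} JSON object. A hedged sketch of parsing that shape with serde_json (this FunctionCall struct is illustrative, not the router's own type):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct FunctionCall {
    name: String,
    parameters: serde_json::Value,
}

fn main() {
    // Example of the output format the new default tool prompt requests.
    let raw = r#"{"name": "get_weather", "parameters": {"city": "Paris"}}"#;
    let call: FunctionCall = serde_json::from_str(raw).unwrap();
    println!("{call:?}");
}
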
@@ -968,7 +968,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
     bos_token: Option<&'a str>,
     eos_token: Option<&'a str>,
     add_generation_prompt: bool,
-    tools: Option<&'a str>,
+    tools: Option<Vec<Tool>>,
     tools_prompt: Option<&'a str>,
     guideline: Option<&'a str>,
 }
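Because tools is now Option<Vec<Tool>> instead of a pre-rendered Option<&str>, minijinja serializes the structs itself and templates can iterate over them. A sketch with a hypothetical two-field Tool stand-in (the router's real Tool carries a full function schema):

use minijinja::{context, Environment};
use serde::Serialize;

// Hypothetical stand-in for the router's `Tool`.
#[derive(Serialize)]
struct Tool {
    name: String,
    description: String,
}

fn main() {
    let tools = vec![Tool {
        name: "get_weather".into(),
        description: "Look up the current weather for a city".into(),
    }];
    let mut env = Environment::new();
    env.add_template(
        "chat",
        "{% for tool in tools %}- {{ tool.name }}: {{ tool.description }}\n{% endfor %}",
    )
    .unwrap();
    let rendered = env
        .get_template("chat")
        .unwrap()
        .render(context! { tools => tools })
        .unwrap();
    print!("{rendered}");
}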

router/src/server.rs

@@ -2562,13 +2562,11 @@ fn prepare_chat_input(
     }

     // if tools are set, apply the tool grammar and then the chat template
-    let tool_grammar: Option<Tools> = ToolGrammar::apply(tools, tool_choice)?;
+    let tool_grammar: Option<Tools> = ToolGrammar::apply(tools.clone(), tool_choice)?;
     let grammar = tool_grammar
         .as_ref()
         .map(|t| GrammarType::Json(serde_json::json!(t)));
-    let tools_grammar_prompt = tool_grammar
-        .as_ref()
-        .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into()));
-    let inputs = infer.apply_chat_template(guideline, messages, tools_grammar_prompt)?;
+    let tools_and_prompt: (Option<Vec<Tool>>, String) = (tools, tool_prompt.into());
+    let inputs = infer.apply_chat_template(guideline, messages, Some(tools_and_prompt))?;
     Ok((inputs, grammar, tool_grammar))
 }