fix[router]: Fix tools not passed in chat template

Signed-off-by: GitHub <noreply@github.com>
Simone Rossi 2024-08-22 15:48:37 +00:00 committed by GitHub
parent 358ceb67dd
commit 9a3e838079
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 37 additions and 31 deletions

Cargo.lock (generated)
View File

@@ -2174,6 +2174,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d"
 dependencies = [
  "serde",
+ "serde_json",
 ]
 [[package]]

View File

@@ -46,7 +46,7 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 init-tracing-opentelemetry = { version = "0.14.1", features = [
     "opentelemetry-otlp",
 ] }
-minijinja = { version = "2.0.2" }
+minijinja = { version = "2.0.2", features = ["json"] }
 minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
 futures-util = "0.3.30"
 regex = "1.10.3"
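
The newly enabled `json` feature is what gives minijinja templates the built-in `tojson` filter, so structured values such as a tool list can be embedded in the rendered prompt. A minimal standalone sketch of what the feature enables; the template string and tool names here are illustrative, not taken from the router code:

use minijinja::{context, Environment};

fn main() {
    let mut env = Environment::new();
    // The `tojson` filter is only available when minijinja's "json" feature is enabled.
    env.add_template("demo", "Available tools: {{ tools | tojson }}")
        .unwrap();
    let rendered = env
        .get_template("demo")
        .unwrap()
        // Any serde-serializable value can be handed to the template context.
        .render(context! { tools => vec!["get_current_weather", "get_time"] })
        .unwrap();
    println!("{rendered}");
}

Without the feature, rendering a template that pipes a value through `tojson` should fail with an unknown-filter error.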

View File

@@ -1,9 +1,7 @@
 use std::collections::HashSet;
 use crate::infer::InferError;
-use crate::{
-    ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
-};
+use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
 use minijinja::{Environment, ErrorKind, Template};
 use minijinja_contrib::pycompat;
@@ -32,6 +30,7 @@ impl ChatTemplate {
         env.set_unknown_method_callback(pycompat::unknown_method_callback);
         let template_str = template.into_boxed_str();
         env.add_function("raise_exception", raise_exception);
+        tracing::debug!("Loading template: {:#?}", template_str);
         // leaking env and template_str as read-only, static resources for performance.
         let template = Box::leak(env)
@@ -42,6 +41,7 @@ impl ChatTemplate {
         let variables = template.undeclared_variables(true);
         // check if the `tools` variable is used in the template
         let use_default_tool_template = !variables.contains("tools");
+        tracing::debug!("Use default tool template: {}", use_default_tool_template);
         Self {
             template,
@@ -56,36 +56,43 @@ impl ChatTemplate {
         &self,
         guideline: Option<&str>,
         mut messages: Vec<Message>,
-        grammar_with_prompt: Option<(GrammarType, String)>,
+        tools_and_prompt: Option<(Option<Vec<Tool>>, String)>,
     ) -> Result<String, InferError> {
-        if self.use_default_tool_template {
-            if let Some(last_message) = messages.last_mut() {
-                if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
-                    last_message.content.push(MessageChunk::Text {
-                        text: format!("\n---\n{}\n{}", tool_prompt, tools),
-                    });
-                }
-            }
-        }
-        let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
         // check if guideline is expected but not provided
         if self.variables.contains("guideline") && guideline.is_none() {
             return Err(InferError::MissingTemplateVariable("guideline".to_string()));
         }
-        self.template
+        let (tools, tool_prompt) = tools_and_prompt.unwrap_or_default();
+        if tools.is_some() {
+            // check if the `tools` variable is used in the template
+            // if not, we need to append the tools to the last message
+            let text = if self.use_default_tool_template {
+                format!("\n---\n{:?}\n{}", tools, tool_prompt)
+            } else {
+                // if the `tools` variable is used in the template, we just append the tool_prompt
+                format!("\n---\n{}", tool_prompt)
+            };
+            if let Some(last_message) = messages.last_mut() {
+                last_message.content.push(MessageChunk::Text { text });
+            }
+        }
+        let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
+        return self
+            .template
             .render(ChatTemplateInputs {
                 guideline,
                 messages,
                 bos_token: self.bos_token.as_deref(),
                 eos_token: self.eos_token.as_deref(),
                 add_generation_prompt: true,
-                tools: None,
+                tools: tools,
                 tools_prompt: None,
             })
-            .map_err(InferError::TemplateError)
+            .map_err(InferError::TemplateError);
     }
 }
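
The `use_default_tool_template` flag that drives the new branch above comes from minijinja's `undeclared_variables`, which lists every context variable a template reads. A self-contained sketch of that detection; the template string is illustrative rather than a real chat template:

use minijinja::Environment;

fn main() {
    let mut env = Environment::new();
    env.add_template("chat", "{% for m in messages %}{{ m.content }}{% endfor %}")
        .unwrap();
    let template = env.get_template("chat").unwrap();
    // Collect every context variable the template references, including nested lookups.
    let variables = template.undeclared_variables(true);
    // Mirrors the router's check: if the template never mentions `tools`,
    // fall back to appending the serialized tools to the last message.
    let use_default_tool_template = !variables.contains("tools");
    println!("use default tool template: {use_default_tool_template}");
}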

View File

@@ -3,7 +3,7 @@ mod chat_template;
 pub mod tool_grammar;
 use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
-use crate::GrammarType;
+use crate::Tool;
 use crate::{
     ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
     Message, PrefillToken, Token,
@@ -140,12 +140,12 @@ impl Infer {
         &self,
         guideline: Option<String>,
         messages: Vec<Message>,
-        grammar_with_prompt: Option<(GrammarType, String)>,
+        tools_and_prompt: Option<(Option<Vec<Tool>>, String)>,
     ) -> Result<String, InferError> {
         self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(guideline.as_deref(), messages, grammar_with_prompt)
+            .apply(guideline.as_deref(), messages, tools_and_prompt)
             .map_err(|e| {
                 metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
                 tracing::error!("{e}");

View File

@@ -843,7 +843,7 @@ pub(crate) struct ChatRequest {
     #[serde(default = "default_tool_prompt")]
     #[schema(
         nullable = true,
-        example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\""
+        example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
     )]
     pub tool_prompt: Option<String>,
@@ -867,7 +867,7 @@ pub(crate) struct ChatRequest {
 fn default_tool_prompt() -> Option<String> {
     Some(
-        "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
+        "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string(),
     )
 }
@@ -968,7 +968,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
     bos_token: Option<&'a str>,
     eos_token: Option<&'a str>,
     add_generation_prompt: bool,
-    tools: Option<&'a str>,
+    tools: Option<Vec<Tool>>,
     tools_prompt: Option<&'a str>,
     guideline: Option<&'a str>,
 }
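
Since `ChatTemplateInputs` is serialized into the template context, a chat template that references `tools` now receives structured objects rather than a pre-rendered string. The sketch below shows roughly what that value looks like once serialized, using a deliberately trimmed-down, hypothetical stand-in for the router's `Tool` struct (the real type carries the full function schema); it assumes serde with the derive feature and serde_json are available, as they are in this workspace:

use serde::Serialize;

// Hypothetical, simplified stand-in for the router's `Tool` type.
#[derive(Serialize)]
struct Tool {
    r#type: String,
    name: String,
    description: String,
}

fn main() {
    let tools = vec![Tool {
        r#type: "function".to_string(),
        name: "get_current_weather".to_string(),
        description: "Get the current weather for a location".to_string(),
    }];
    // Roughly what a template iterating over `tools`, or piping it through
    // `tojson`, receives once the inputs are serialized into the context.
    println!("{}", serde_json::to_string_pretty(&tools).unwrap());
}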

View File

@@ -2562,13 +2562,11 @@ fn prepare_chat_input(
     }
     // if tools are set, apply the tool grammar and then the chat template
-    let tool_grammar: Option<Tools> = ToolGrammar::apply(tools, tool_choice)?;
+    let tool_grammar: Option<Tools> = ToolGrammar::apply(tools.clone(), tool_choice)?;
     let grammar = tool_grammar
         .as_ref()
         .map(|t| GrammarType::Json(serde_json::json!(t)));
-    let tools_grammar_prompt = tool_grammar
-        .as_ref()
-        .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into()));
-    let inputs = infer.apply_chat_template(guideline, messages, tools_grammar_prompt)?;
+    let tools_and_prompt: (Option<Vec<Tool>>, String) = (tools, tool_prompt.into());
+    let inputs = infer.apply_chat_template(guideline, messages, Some(tools_and_prompt))?;
     Ok((inputs, grammar, tool_grammar))
 }
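
After this change `prepare_chat_input` uses the tool definitions twice: once serialized into a JSON grammar for constrained decoding, and once passed as structured values together with the tool prompt to the chat template. A simplified, self-contained sketch of that split, with `serde_json::Value` standing in for the router's `Tools` and `GrammarType` types:

use serde_json::json;

fn main() {
    // Stand-in for the validated tool list coming off the request.
    let tools = Some(vec![json!({
        "type": "function",
        "function": { "name": "get_current_weather" }
    })]);
    let tool_prompt = "Respond with a JSON function call.";

    // 1) The same tools are serialized into a grammar for constrained decoding.
    let grammar = tools.as_ref().map(|t| json!(t));

    // 2) The structured tools and the prompt are forwarded to the chat template.
    let tools_and_prompt = (tools.clone(), tool_prompt.to_string());

    println!("grammar: {}", grammar.unwrap());
    println!("template input: {:?}", tools_and_prompt);
}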