mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix[router]: Fix tools not passed in chat template
Signed-off-by: GitHub <noreply@github.com>
This commit is contained in:
parent
358ceb67dd
commit
9a3e838079
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -2174,6 +2174,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -46,7 +46,7 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
||||
"opentelemetry-otlp",
|
||||
] }
|
||||
minijinja = { version = "2.0.2" }
|
||||
minijinja = { version = "2.0.2", features = ["json"] }
|
||||
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
||||
futures-util = "0.3.30"
|
||||
regex = "1.10.3"
|
||||
|
@ -1,9 +1,7 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use crate::infer::InferError;
|
||||
use crate::{
|
||||
ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
|
||||
};
|
||||
use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
|
||||
use minijinja::{Environment, ErrorKind, Template};
|
||||
use minijinja_contrib::pycompat;
|
||||
|
||||
@ -32,6 +30,7 @@ impl ChatTemplate {
|
||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||
let template_str = template.into_boxed_str();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
tracing::debug!("Loading template: {:#?}", template_str);
|
||||
|
||||
// leaking env and template_str as read-only, static resources for performance.
|
||||
let template = Box::leak(env)
|
||||
@ -42,6 +41,7 @@ impl ChatTemplate {
|
||||
let variables = template.undeclared_variables(true);
|
||||
// check if the `tools` variable is used in the template
|
||||
let use_default_tool_template = !variables.contains("tools");
|
||||
tracing::debug!("Use default tool template: {}", use_default_tool_template);
|
||||
|
||||
Self {
|
||||
template,
|
||||
@ -56,36 +56,43 @@ impl ChatTemplate {
|
||||
&self,
|
||||
guideline: Option<&str>,
|
||||
mut messages: Vec<Message>,
|
||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
||||
tools_and_prompt: Option<(Option<Vec<Tool>>, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
if self.use_default_tool_template {
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
||||
last_message.content.push(MessageChunk::Text {
|
||||
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||
|
||||
// check if guideline is expected but not provided
|
||||
if self.variables.contains("guideline") && guideline.is_none() {
|
||||
return Err(InferError::MissingTemplateVariable("guideline".to_string()));
|
||||
}
|
||||
|
||||
self.template
|
||||
let (tools, tool_prompt) = tools_and_prompt.unwrap_or_default();
|
||||
|
||||
if tools.is_some() {
|
||||
// check if the `tools` variable is used in the template
|
||||
// if not, we need to append the tools to the last message
|
||||
let text = if self.use_default_tool_template {
|
||||
format!("\n---\n{:?}\n{}", tools, tool_prompt)
|
||||
} else {
|
||||
// if the `tools` variable is used in the template, we just append the tool_prompt
|
||||
format!("\n---\n{}", tool_prompt)
|
||||
};
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
last_message.content.push(MessageChunk::Text { text });
|
||||
}
|
||||
}
|
||||
|
||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||
|
||||
return self
|
||||
.template
|
||||
.render(ChatTemplateInputs {
|
||||
guideline,
|
||||
messages,
|
||||
bos_token: self.bos_token.as_deref(),
|
||||
eos_token: self.eos_token.as_deref(),
|
||||
add_generation_prompt: true,
|
||||
tools: None,
|
||||
tools: tools,
|
||||
tools_prompt: None,
|
||||
})
|
||||
.map_err(InferError::TemplateError)
|
||||
.map_err(InferError::TemplateError);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3,7 +3,7 @@ mod chat_template;
|
||||
pub mod tool_grammar;
|
||||
|
||||
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
||||
use crate::GrammarType;
|
||||
use crate::Tool;
|
||||
use crate::{
|
||||
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
||||
Message, PrefillToken, Token,
|
||||
@ -140,12 +140,12 @@ impl Infer {
|
||||
&self,
|
||||
guideline: Option<String>,
|
||||
messages: Vec<Message>,
|
||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
||||
tools_and_prompt: Option<(Option<Vec<Tool>>, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
self.chat_template
|
||||
.as_ref()
|
||||
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
||||
.apply(guideline.as_deref(), messages, grammar_with_prompt)
|
||||
.apply(guideline.as_deref(), messages, tools_and_prompt)
|
||||
.map_err(|e| {
|
||||
metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
|
||||
tracing::error!("{e}");
|
||||
|
@ -843,7 +843,7 @@ pub(crate) struct ChatRequest {
|
||||
#[serde(default = "default_tool_prompt")]
|
||||
#[schema(
|
||||
nullable = true,
|
||||
example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\""
|
||||
example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
|
||||
)]
|
||||
pub tool_prompt: Option<String>,
|
||||
|
||||
@ -867,7 +867,7 @@ pub(crate) struct ChatRequest {
|
||||
|
||||
/// Supplies the fallback value for the `tool_prompt` field of `ChatRequest`; it is wired in
/// through `#[serde(default = "default_tool_prompt")]` on that field.
///
/// Always returns `Some(..)` holding an owned `String` copy of the built-in prompt text.
// NOTE(review): this listing is a rendered diff, so two string literals appear below.
// Presumably the first ("You will be presented with a JSON schema ...") is the removed prompt
// and the second ("Given the functions available ...") is its replacement — confirm against
// the repository before relying on either.
fn default_tool_prompt() -> Option<String> {
|
||||
Some(
|
||||
"\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
|
||||
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string(),
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
@ -968,7 +968,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
|
||||
bos_token: Option<&'a str>,
|
||||
eos_token: Option<&'a str>,
|
||||
add_generation_prompt: bool,
|
||||
tools: Option<&'a str>,
|
||||
tools: Option<Vec<Tool>>,
|
||||
tools_prompt: Option<&'a str>,
|
||||
guideline: Option<&'a str>,
|
||||
}
|
||||
|
@ -2562,13 +2562,11 @@ fn prepare_chat_input(
|
||||
}
|
||||
|
||||
// if tools are set, apply the tool grammar and then the chat template
|
||||
let tool_grammar: Option<Tools> = ToolGrammar::apply(tools, tool_choice)?;
|
||||
let tool_grammar: Option<Tools> = ToolGrammar::apply(tools.clone(), tool_choice)?;
|
||||
let grammar = tool_grammar
|
||||
.as_ref()
|
||||
.map(|t| GrammarType::Json(serde_json::json!(t)));
|
||||
let tools_grammar_prompt = tool_grammar
|
||||
.as_ref()
|
||||
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into()));
|
||||
let inputs = infer.apply_chat_template(guideline, messages, tools_grammar_prompt)?;
|
||||
let tools_and_prompt: (Option<Vec<Tool>>, String) = (tools, tool_prompt.into());
|
||||
let inputs = infer.apply_chat_template(guideline, messages, Some(tools_and_prompt))?;
|
||||
Ok((inputs, grammar, tool_grammar))
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user