mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 12:54:52 +00:00
fix[router]: Fix tools not passed in chat template
Signed-off-by: GitHub <noreply@github.com>
This commit is contained in:
parent
30be188400
commit
2ee98c7c07
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -2174,6 +2174,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d"
|
checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
|
"serde_json",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -46,7 +46,7 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
|||||||
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
||||||
"opentelemetry-otlp",
|
"opentelemetry-otlp",
|
||||||
] }
|
] }
|
||||||
minijinja = { version = "2.0.2" }
|
minijinja = { version = "2.0.2", features = ["json"] }
|
||||||
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
||||||
futures-util = "0.3.30"
|
futures-util = "0.3.30"
|
||||||
regex = "1.10.3"
|
regex = "1.10.3"
|
||||||
|
@ -1,9 +1,7 @@
|
|||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
|
||||||
use crate::infer::InferError;
|
use crate::infer::InferError;
|
||||||
use crate::{
|
use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
|
||||||
ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
|
|
||||||
};
|
|
||||||
use minijinja::{Environment, ErrorKind, Template};
|
use minijinja::{Environment, ErrorKind, Template};
|
||||||
use minijinja_contrib::pycompat;
|
use minijinja_contrib::pycompat;
|
||||||
|
|
||||||
@ -32,6 +30,7 @@ impl ChatTemplate {
|
|||||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||||
let template_str = template.into_boxed_str();
|
let template_str = template.into_boxed_str();
|
||||||
env.add_function("raise_exception", raise_exception);
|
env.add_function("raise_exception", raise_exception);
|
||||||
|
tracing::debug!("Loading template: {:#?}", template_str);
|
||||||
|
|
||||||
// leaking env and template_str as read-only, static resources for performance.
|
// leaking env and template_str as read-only, static resources for performance.
|
||||||
let template = Box::leak(env)
|
let template = Box::leak(env)
|
||||||
@ -42,6 +41,7 @@ impl ChatTemplate {
|
|||||||
let variables = template.undeclared_variables(true);
|
let variables = template.undeclared_variables(true);
|
||||||
// check if the `tools` variable is used in the template
|
// check if the `tools` variable is used in the template
|
||||||
let use_default_tool_template = !variables.contains("tools");
|
let use_default_tool_template = !variables.contains("tools");
|
||||||
|
tracing::debug!("Use default tool template: {}", use_default_tool_template);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
template,
|
template,
|
||||||
@ -56,36 +56,43 @@ impl ChatTemplate {
|
|||||||
&self,
|
&self,
|
||||||
guideline: Option<&str>,
|
guideline: Option<&str>,
|
||||||
mut messages: Vec<Message>,
|
mut messages: Vec<Message>,
|
||||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
tools_and_prompt: Option<(Option<Vec<Tool>>, String)>,
|
||||||
) -> Result<String, InferError> {
|
) -> Result<String, InferError> {
|
||||||
if self.use_default_tool_template {
|
|
||||||
if let Some(last_message) = messages.last_mut() {
|
|
||||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
|
||||||
last_message.content.push(MessageChunk::Text {
|
|
||||||
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
|
||||||
|
|
||||||
// check if guideline is expected but not provided
|
// check if guideline is expected but not provided
|
||||||
if self.variables.contains("guideline") && guideline.is_none() {
|
if self.variables.contains("guideline") && guideline.is_none() {
|
||||||
return Err(InferError::MissingTemplateVariable("guideline".to_string()));
|
return Err(InferError::MissingTemplateVariable("guideline".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
self.template
|
let (tools, tool_prompt) = tools_and_prompt.unwrap_or_default();
|
||||||
|
|
||||||
|
if tools.is_some() {
|
||||||
|
// check if the `tools` variable is used in the template
|
||||||
|
// if not, we need to append the tools to the last message
|
||||||
|
let text = if self.use_default_tool_template {
|
||||||
|
format!("\n---\n{:?}\n{}", tools, tool_prompt)
|
||||||
|
} else {
|
||||||
|
// if the `tools` variable is used in the template, we just append the tool_prompt
|
||||||
|
format!("\n---\n{}", tool_prompt)
|
||||||
|
};
|
||||||
|
if let Some(last_message) = messages.last_mut() {
|
||||||
|
last_message.content.push(MessageChunk::Text { text });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||||
|
|
||||||
|
return self
|
||||||
|
.template
|
||||||
.render(ChatTemplateInputs {
|
.render(ChatTemplateInputs {
|
||||||
guideline,
|
guideline,
|
||||||
messages,
|
messages,
|
||||||
bos_token: self.bos_token.as_deref(),
|
bos_token: self.bos_token.as_deref(),
|
||||||
eos_token: self.eos_token.as_deref(),
|
eos_token: self.eos_token.as_deref(),
|
||||||
add_generation_prompt: true,
|
add_generation_prompt: true,
|
||||||
tools: None,
|
tools: tools,
|
||||||
tools_prompt: None,
|
tools_prompt: None,
|
||||||
})
|
})
|
||||||
.map_err(InferError::TemplateError)
|
.map_err(InferError::TemplateError);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ mod chat_template;
|
|||||||
pub mod tool_grammar;
|
pub mod tool_grammar;
|
||||||
|
|
||||||
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
||||||
use crate::GrammarType;
|
use crate::Tool;
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
||||||
Message, PrefillToken, Token,
|
Message, PrefillToken, Token,
|
||||||
@ -140,12 +140,12 @@ impl Infer {
|
|||||||
&self,
|
&self,
|
||||||
guideline: Option<String>,
|
guideline: Option<String>,
|
||||||
messages: Vec<Message>,
|
messages: Vec<Message>,
|
||||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
tools_and_prompt: Option<(Option<Vec<Tool>>, String)>,
|
||||||
) -> Result<String, InferError> {
|
) -> Result<String, InferError> {
|
||||||
self.chat_template
|
self.chat_template
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
||||||
.apply(guideline.as_deref(), messages, grammar_with_prompt)
|
.apply(guideline.as_deref(), messages, tools_and_prompt)
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
|
metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
|
||||||
tracing::error!("{e}");
|
tracing::error!("{e}");
|
||||||
|
@ -843,7 +843,7 @@ pub(crate) struct ChatRequest {
|
|||||||
#[serde(default = "default_tool_prompt")]
|
#[serde(default = "default_tool_prompt")]
|
||||||
#[schema(
|
#[schema(
|
||||||
nullable = true,
|
nullable = true,
|
||||||
example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\""
|
example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
|
||||||
)]
|
)]
|
||||||
pub tool_prompt: Option<String>,
|
pub tool_prompt: Option<String>,
|
||||||
|
|
||||||
@ -867,7 +867,7 @@ pub(crate) struct ChatRequest {
|
|||||||
|
|
||||||
fn default_tool_prompt() -> Option<String> {
|
fn default_tool_prompt() -> Option<String> {
|
||||||
Some(
|
Some(
|
||||||
"\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
|
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -968,7 +968,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
|
|||||||
bos_token: Option<&'a str>,
|
bos_token: Option<&'a str>,
|
||||||
eos_token: Option<&'a str>,
|
eos_token: Option<&'a str>,
|
||||||
add_generation_prompt: bool,
|
add_generation_prompt: bool,
|
||||||
tools: Option<&'a str>,
|
tools: Option<Vec<Tool>>,
|
||||||
tools_prompt: Option<&'a str>,
|
tools_prompt: Option<&'a str>,
|
||||||
guideline: Option<&'a str>,
|
guideline: Option<&'a str>,
|
||||||
}
|
}
|
||||||
|
@ -2562,13 +2562,11 @@ fn prepare_chat_input(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// if tools are set, apply the tool grammar and then the chat template
|
// if tools are set, apply the tool grammar and then the chat template
|
||||||
let tool_grammar: Option<Tools> = ToolGrammar::apply(tools, tool_choice)?;
|
let tool_grammar: Option<Tools> = ToolGrammar::apply(tools.clone(), tool_choice)?;
|
||||||
let grammar = tool_grammar
|
let grammar = tool_grammar
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|t| GrammarType::Json(serde_json::json!(t)));
|
.map(|t| GrammarType::Json(serde_json::json!(t)));
|
||||||
let tools_grammar_prompt = tool_grammar
|
let tools_and_prompt: (Option<Vec<Tool>>, String) = (tools, tool_prompt.into());
|
||||||
.as_ref()
|
let inputs = infer.apply_chat_template(guideline, messages, Some(tools_and_prompt))?;
|
||||||
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into()));
|
|
||||||
let inputs = infer.apply_chat_template(guideline, messages, tools_grammar_prompt)?;
|
|
||||||
Ok((inputs, grammar, tool_grammar))
|
Ok((inputs, grammar, tool_grammar))
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user