mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
feat: refactor tool logic to include notify_error in prompt and adjust typing
This commit is contained in:
parent
9ea34977ac
commit
1bf0e3b65c
@ -924,7 +924,7 @@
|
||||
"tool_prompt": {
|
||||
"type": "string",
|
||||
"description": "A prompt to be appended before the tools",
|
||||
"example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"",
|
||||
"example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.",
|
||||
"nullable": true
|
||||
},
|
||||
"tools": {
|
||||
|
@ -56,31 +56,33 @@ impl ChatTemplate {
|
||||
&self,
|
||||
guideline: Option<&str>,
|
||||
mut messages: Vec<Message>,
|
||||
tools_and_prompt: Option<(Option<Vec<Tool>>, String)>,
|
||||
tools_and_prompt: Option<(Vec<Tool>, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
// check if guideline is expected but not provided
|
||||
if self.variables.contains("guideline") && guideline.is_none() {
|
||||
return Err(InferError::MissingTemplateVariable("guideline".to_string()));
|
||||
}
|
||||
|
||||
let (tools, tool_prompt) = tools_and_prompt.unwrap_or_default();
|
||||
|
||||
if let Some(ref tools) = tools {
|
||||
// check if the `tools` variable is used in the template
|
||||
// if not, we need to append the tools to the last message
|
||||
let text = if self.use_default_tool_template {
|
||||
match serde_json::to_string(tools) {
|
||||
Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
|
||||
Err(e) => return Err(InferError::ToolError(e.to_string())),
|
||||
let tools = match tools_and_prompt {
|
||||
Some((tools, tool_prompt)) => {
|
||||
// check if the `tools` variable is used in the template
|
||||
// if not, we need to append the tools to the last message
|
||||
let text = if self.use_default_tool_template {
|
||||
match serde_json::to_string(&tools) {
|
||||
Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
|
||||
Err(e) => return Err(InferError::ToolError(e.to_string())),
|
||||
}
|
||||
} else {
|
||||
// if the `tools` variable is used in the template, we just append the tool_prompt
|
||||
format!("\n---\n{}", tool_prompt)
|
||||
};
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
last_message.content.push(MessageChunk::Text { text });
|
||||
}
|
||||
} else {
|
||||
// if the `tools` variable is used in the template, we just append the tool_prompt
|
||||
format!("\n---\n{}", tool_prompt)
|
||||
};
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
last_message.content.push(MessageChunk::Text { text });
|
||||
Some(tools)
|
||||
}
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||
|
||||
@ -92,7 +94,6 @@ impl ChatTemplate {
|
||||
eos_token: self.eos_token.as_deref(),
|
||||
add_generation_prompt: true,
|
||||
tools,
|
||||
tools_prompt: None,
|
||||
})
|
||||
.map_err(InferError::TemplateError)
|
||||
}
|
||||
@ -104,8 +105,7 @@ mod tests {
|
||||
use crate::infer::chat_template::raise_exception;
|
||||
use crate::infer::ChatTemplate;
|
||||
use crate::{
|
||||
ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage,
|
||||
TokenizerConfigToken, Tool,
|
||||
ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool,
|
||||
};
|
||||
use minijinja::Environment;
|
||||
|
||||
@ -867,7 +867,7 @@ mod tests {
|
||||
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
|
||||
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
|
||||
let tool_prompt = "This default prompt will be used".to_string();
|
||||
let tools_and_prompt = Some((Some(tools), tool_prompt));
|
||||
let tools_and_prompt = Some((tools, tool_prompt));
|
||||
let result = ct.apply(None, msgs, tools_and_prompt);
|
||||
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
|
||||
assert_eq!(result.unwrap(), expected);
|
||||
|
@ -140,7 +140,7 @@ impl Infer {
|
||||
&self,
|
||||
guideline: Option<String>,
|
||||
messages: Vec<Message>,
|
||||
tools_and_prompt: Option<(Option<Vec<Tool>>, String)>,
|
||||
tools_and_prompt: Option<(Vec<Tool>, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
self.chat_template
|
||||
.as_ref()
|
||||
|
@ -1,5 +1,8 @@
|
||||
use crate::infer::InferError;
|
||||
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools};
|
||||
use crate::{
|
||||
FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
|
||||
ToolType,
|
||||
};
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@ -16,17 +19,42 @@ impl ToolGrammar {
|
||||
}
|
||||
|
||||
pub fn apply(
|
||||
tools: Option<Vec<Tool>>,
|
||||
tools: Vec<Tool>,
|
||||
tool_choice: ToolChoice,
|
||||
) -> Result<Option<Tools>, InferError> {
|
||||
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
|
||||
// if no tools are provided, we return None
|
||||
let tools = match tools {
|
||||
Some(tools) if !tools.is_empty() => tools,
|
||||
_ => return Ok(None),
|
||||
};
|
||||
if tools.is_empty() {
|
||||
return Ok((tools, None));
|
||||
}
|
||||
|
||||
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
|
||||
|
||||
let mut tools = tools.clone();
|
||||
|
||||
// add the notify_error function to the tools
|
||||
let notify_error = Tool {
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionDefinition {
|
||||
name: "notify_error".to_string(),
|
||||
description: Some("Notify an error or issue".to_string()),
|
||||
arguments: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"error": {
|
||||
"type": "string",
|
||||
"description": "The error or issue to notify"
|
||||
},
|
||||
"_name": {
|
||||
"type": "string",
|
||||
"const": "notify_error"
|
||||
}
|
||||
},
|
||||
"required": ["error", "_name"]
|
||||
}),
|
||||
},
|
||||
};
|
||||
tools.push(notify_error);
|
||||
|
||||
// if tools are provided and no tool_choice we default to the OneOf
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolType::FunctionName(name) => {
|
||||
@ -35,27 +63,10 @@ impl ToolGrammar {
|
||||
ToolType::Function { function } => {
|
||||
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
||||
}
|
||||
ToolType::OneOf => tools,
|
||||
ToolType::NoTool => return Ok(None),
|
||||
ToolType::OneOf => tools.clone(),
|
||||
ToolType::NoTool => return Ok((tools, None)),
|
||||
};
|
||||
|
||||
// adds the error notification function for LLM feedback if required
|
||||
let mut text_response_properties = Map::new();
|
||||
text_response_properties.insert(
|
||||
"error".to_string(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"description": "The error or issue to notify"
|
||||
}),
|
||||
);
|
||||
text_response_properties.insert(
|
||||
"_name".to_string(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"const": "notify_error"
|
||||
}),
|
||||
);
|
||||
|
||||
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
@ -105,17 +116,9 @@ impl ToolGrammar {
|
||||
|
||||
(func.name, Value::Object(params))
|
||||
})
|
||||
.chain([(
|
||||
"notify_error".to_string(),
|
||||
serde_json::json!({
|
||||
"properties": text_response_properties,
|
||||
"required": ["error", "_name"],
|
||||
"type": "object"
|
||||
}),
|
||||
)])
|
||||
.collect();
|
||||
|
||||
let tools = Tools {
|
||||
let tool_schema = JsonSchemaTool {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
@ -130,6 +133,6 @@ impl ToolGrammar {
|
||||
},
|
||||
};
|
||||
|
||||
Ok(Some(tools))
|
||||
Ok((tools, Some(tool_schema)))
|
||||
}
|
||||
}
|
||||
|
@ -910,7 +910,7 @@ impl From<ToolTypeDeserializer> for ToolChoice {
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
|
||||
pub struct Tools {
|
||||
pub struct JsonSchemaTool {
|
||||
#[serde(flatten)]
|
||||
functions_map: FunctionsMap,
|
||||
properties: Properties,
|
||||
@ -969,7 +969,6 @@ pub(crate) struct ChatTemplateInputs<'a> {
|
||||
eos_token: Option<&'a str>,
|
||||
add_generation_prompt: bool,
|
||||
tools: Option<Vec<Tool>>,
|
||||
tools_prompt: Option<&'a str>,
|
||||
guideline: Option<&'a str>,
|
||||
}
|
||||
|
||||
@ -1207,7 +1206,6 @@ pub(crate) struct GenerateResponse {
|
||||
pub(crate) struct ChatTokenizeResponse {
|
||||
pub(crate) tokenize_response: TokenizeResponse,
|
||||
pub(crate) templated_text: String,
|
||||
pub(crate) using_tools: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
|
@ -23,7 +23,7 @@ use crate::{
|
||||
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
|
||||
VertexResponse,
|
||||
};
|
||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools};
|
||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
|
||||
use async_stream::__private::AsyncStream;
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
|
||||
@ -146,7 +146,7 @@ async fn get_chat_tokenize(
|
||||
} = req;
|
||||
|
||||
let tool_prompt = tool_prompt.unwrap_or_default();
|
||||
let (inputs, _grammar, using_tools) = prepare_chat_input(
|
||||
let (inputs, _grammar, _using_tools) = prepare_chat_input(
|
||||
&infer,
|
||||
response_format,
|
||||
tools,
|
||||
@ -206,7 +206,6 @@ async fn get_chat_tokenize(
|
||||
let resp = ChatTokenizeResponse {
|
||||
tokenize_response: TokenizeResponse(tokens),
|
||||
templated_text: input,
|
||||
using_tools,
|
||||
};
|
||||
Ok((HeaderMap::new(), Json(resp)))
|
||||
} else {
|
||||
@ -2562,28 +2561,28 @@ fn prepare_chat_input(
|
||||
return Ok((inputs, Some(format), false));
|
||||
}
|
||||
|
||||
// if tools are set, apply the tool grammar and then the chat template
|
||||
let tool_grammar: Option<Tools> = ToolGrammar::apply(tools.clone(), tool_choice)?;
|
||||
let grammar = tool_grammar
|
||||
let (updated_tools, tool_schema) = ToolGrammar::apply(tools.unwrap().clone(), tool_choice)?;
|
||||
|
||||
let grammar = tool_schema
|
||||
.as_ref()
|
||||
.map(|t| GrammarType::Json(serde_json::json!(t)));
|
||||
let tools_and_prompt: (Option<Vec<Tool>>, String) = (tools, tool_prompt.into());
|
||||
let inputs = infer.apply_chat_template(guideline, messages, Some(tools_and_prompt))?;
|
||||
Ok((inputs, grammar, tool_grammar.is_some()))
|
||||
|
||||
let inputs: String = infer.apply_chat_template(
|
||||
guideline,
|
||||
messages,
|
||||
Some((updated_tools, tool_prompt.into())),
|
||||
)?;
|
||||
|
||||
Ok((inputs, grammar, tool_schema.is_some()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::*;
|
||||
use crate::ChatTemplateVersions;
|
||||
use crate::FunctionsMap;
|
||||
use crate::HubTokenizerConfig;
|
||||
use crate::Properties;
|
||||
use crate::TokenizerConfigToken;
|
||||
use crate::Tool;
|
||||
use crate::Tools;
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
@ -2595,26 +2594,26 @@ mod tests {
|
||||
impl Backend for MockBackend {
|
||||
fn schedule(
|
||||
&self,
|
||||
request: crate::validation::ValidGenerateRequest,
|
||||
_request: crate::validation::ValidGenerateRequest,
|
||||
) -> Result<
|
||||
tokio_stream::wrappers::UnboundedReceiverStream<
|
||||
Result<InferStreamResponse, InferError>,
|
||||
>,
|
||||
InferError,
|
||||
> {
|
||||
unimplemented!()
|
||||
unimplemented!("Never called in this test");
|
||||
}
|
||||
fn health<'life0, 'async_trait>(
|
||||
&'life0 self,
|
||||
current_health: bool,
|
||||
fn health<'a, 'async_trait>(
|
||||
&'a self,
|
||||
_current_health: bool,
|
||||
) -> core::pin::Pin<
|
||||
Box<dyn core::future::Future<Output = bool> + core::marker::Send + 'async_trait>,
|
||||
>
|
||||
where
|
||||
'life0: 'async_trait,
|
||||
'a: 'async_trait,
|
||||
Self: 'async_trait,
|
||||
{
|
||||
unimplemented!()
|
||||
unimplemented!("Never called in this test");
|
||||
}
|
||||
}
|
||||
|
||||
@ -2680,8 +2679,8 @@ mod tests {
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
let (inputs, grammar, using_tools) = result.unwrap();
|
||||
let (inputs, _grammar, using_tools) = result.unwrap();
|
||||
assert_eq!(using_tools, true);
|
||||
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
|
||||
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"_name\":{\"const\":\"notify_error\",\"type\":\"string\"},\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\",\"_name\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user