feat: refactor tool logic to include notify_error in prompt and adjust typing

This commit is contained in:
drbh 2024-08-23 21:07:43 +00:00
parent 9ea34977ac
commit 1bf0e3b65c
6 changed files with 85 additions and 85 deletions

View File

@ -924,7 +924,7 @@
"tool_prompt": { "tool_prompt": {
"type": "string", "type": "string",
"description": "A prompt to be appended before the tools", "description": "A prompt to be appended before the tools",
"example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"", "example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.",
"nullable": true "nullable": true
}, },
"tools": { "tools": {

View File

@ -56,20 +56,19 @@ impl ChatTemplate {
&self, &self,
guideline: Option<&str>, guideline: Option<&str>,
mut messages: Vec<Message>, mut messages: Vec<Message>,
tools_and_prompt: Option<(Option<Vec<Tool>>, String)>, tools_and_prompt: Option<(Vec<Tool>, String)>,
) -> Result<String, InferError> { ) -> Result<String, InferError> {
// check if guideline is expected but not provided // check if guideline is expected but not provided
if self.variables.contains("guideline") && guideline.is_none() { if self.variables.contains("guideline") && guideline.is_none() {
return Err(InferError::MissingTemplateVariable("guideline".to_string())); return Err(InferError::MissingTemplateVariable("guideline".to_string()));
} }
let (tools, tool_prompt) = tools_and_prompt.unwrap_or_default(); let tools = match tools_and_prompt {
Some((tools, tool_prompt)) => {
if let Some(ref tools) = tools {
// check if the `tools` variable is used in the template // check if the `tools` variable is used in the template
// if not, we need to append the tools to the last message // if not, we need to append the tools to the last message
let text = if self.use_default_tool_template { let text = if self.use_default_tool_template {
match serde_json::to_string(tools) { match serde_json::to_string(&tools) {
Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt), Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
Err(e) => return Err(InferError::ToolError(e.to_string())), Err(e) => return Err(InferError::ToolError(e.to_string())),
} }
@ -80,7 +79,10 @@ impl ChatTemplate {
if let Some(last_message) = messages.last_mut() { if let Some(last_message) = messages.last_mut() {
last_message.content.push(MessageChunk::Text { text }); last_message.content.push(MessageChunk::Text { text });
} }
Some(tools)
} }
None => None,
};
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect(); let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
@ -92,7 +94,6 @@ impl ChatTemplate {
eos_token: self.eos_token.as_deref(), eos_token: self.eos_token.as_deref(),
add_generation_prompt: true, add_generation_prompt: true,
tools, tools,
tools_prompt: None,
}) })
.map_err(InferError::TemplateError) .map_err(InferError::TemplateError)
} }
@ -104,8 +105,7 @@ mod tests {
use crate::infer::chat_template::raise_exception; use crate::infer::chat_template::raise_exception;
use crate::infer::ChatTemplate; use crate::infer::ChatTemplate;
use crate::{ use crate::{
ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage, ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool,
TokenizerConfigToken, Tool,
}; };
use minijinja::Environment; use minijinja::Environment;
@ -867,7 +867,7 @@ mod tests {
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap(); let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
let tool_prompt = "This default prompt will be used".to_string(); let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((Some(tools), tool_prompt)); let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(None, msgs, tools_and_prompt); let result = ct.apply(None, msgs, tools_and_prompt);
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string(); let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
assert_eq!(result.unwrap(), expected); assert_eq!(result.unwrap(), expected);

View File

@ -140,7 +140,7 @@ impl Infer {
&self, &self,
guideline: Option<String>, guideline: Option<String>,
messages: Vec<Message>, messages: Vec<Message>,
tools_and_prompt: Option<(Option<Vec<Tool>>, String)>, tools_and_prompt: Option<(Vec<Tool>, String)>,
) -> Result<String, InferError> { ) -> Result<String, InferError> {
self.chat_template self.chat_template
.as_ref() .as_ref()

View File

@ -1,5 +1,8 @@
use crate::infer::InferError; use crate::infer::InferError;
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools}; use crate::{
FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
ToolType,
};
use serde_json::{json, Map, Value}; use serde_json::{json, Map, Value};
use std::collections::HashMap; use std::collections::HashMap;
@ -16,17 +19,42 @@ impl ToolGrammar {
} }
pub fn apply( pub fn apply(
tools: Option<Vec<Tool>>, tools: Vec<Tool>,
tool_choice: ToolChoice, tool_choice: ToolChoice,
) -> Result<Option<Tools>, InferError> { ) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
// if no tools are provided, we return None // if no tools are provided, we return None
let tools = match tools { if tools.is_empty() {
Some(tools) if !tools.is_empty() => tools, return Ok((tools, None));
_ => return Ok(None), }
};
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
let mut tools = tools.clone();
// add the notify_error function to the tools
let notify_error = Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
name: "notify_error".to_string(),
description: Some("Notify an error or issue".to_string()),
arguments: json!({
"type": "object",
"properties": {
"error": {
"type": "string",
"description": "The error or issue to notify"
},
"_name": {
"type": "string",
"const": "notify_error"
}
},
"required": ["error", "_name"]
}),
},
};
tools.push(notify_error);
// if tools are provided and no tool_choice we default to the OneOf // if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice { let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => { ToolType::FunctionName(name) => {
@ -35,27 +63,10 @@ impl ToolGrammar {
ToolType::Function { function } => { ToolType::Function { function } => {
vec![Self::find_tool_by_name(&tools, &function.name)?] vec![Self::find_tool_by_name(&tools, &function.name)?]
} }
ToolType::OneOf => tools, ToolType::OneOf => tools.clone(),
ToolType::NoTool => return Ok(None), ToolType::NoTool => return Ok((tools, None)),
}; };
// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);
let functions: HashMap<String, serde_json::Value> = tools_to_use let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter() .iter()
.map(|tool| { .map(|tool| {
@ -105,17 +116,9 @@ impl ToolGrammar {
(func.name, Value::Object(params)) (func.name, Value::Object(params))
}) })
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect(); .collect();
let tools = Tools { let tool_schema = JsonSchemaTool {
functions_map: FunctionsMap { functions }, functions_map: FunctionsMap { functions },
properties: Properties { properties: Properties {
function: tools_to_use function: tools_to_use
@ -130,6 +133,6 @@ impl ToolGrammar {
}, },
}; };
Ok(Some(tools)) Ok((tools, Some(tool_schema)))
} }
} }

View File

@ -910,7 +910,7 @@ impl From<ToolTypeDeserializer> for ToolChoice {
} }
#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)] #[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
pub struct Tools { pub struct JsonSchemaTool {
#[serde(flatten)] #[serde(flatten)]
functions_map: FunctionsMap, functions_map: FunctionsMap,
properties: Properties, properties: Properties,
@ -969,7 +969,6 @@ pub(crate) struct ChatTemplateInputs<'a> {
eos_token: Option<&'a str>, eos_token: Option<&'a str>,
add_generation_prompt: bool, add_generation_prompt: bool,
tools: Option<Vec<Tool>>, tools: Option<Vec<Tool>>,
tools_prompt: Option<&'a str>,
guideline: Option<&'a str>, guideline: Option<&'a str>,
} }
@ -1207,7 +1206,6 @@ pub(crate) struct GenerateResponse {
pub(crate) struct ChatTokenizeResponse { pub(crate) struct ChatTokenizeResponse {
pub(crate) tokenize_response: TokenizeResponse, pub(crate) tokenize_response: TokenizeResponse,
pub(crate) templated_text: String, pub(crate) templated_text: String,
pub(crate) using_tools: bool,
} }
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]

View File

@ -23,7 +23,7 @@ use crate::{
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
VertexResponse, VertexResponse,
}; };
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools}; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
use async_stream::__private::AsyncStream; use async_stream::__private::AsyncStream;
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
@ -146,7 +146,7 @@ async fn get_chat_tokenize(
} = req; } = req;
let tool_prompt = tool_prompt.unwrap_or_default(); let tool_prompt = tool_prompt.unwrap_or_default();
let (inputs, _grammar, using_tools) = prepare_chat_input( let (inputs, _grammar, _using_tools) = prepare_chat_input(
&infer, &infer,
response_format, response_format,
tools, tools,
@ -206,7 +206,6 @@ async fn get_chat_tokenize(
let resp = ChatTokenizeResponse { let resp = ChatTokenizeResponse {
tokenize_response: TokenizeResponse(tokens), tokenize_response: TokenizeResponse(tokens),
templated_text: input, templated_text: input,
using_tools,
}; };
Ok((HeaderMap::new(), Json(resp))) Ok((HeaderMap::new(), Json(resp)))
} else { } else {
@ -2562,28 +2561,28 @@ fn prepare_chat_input(
return Ok((inputs, Some(format), false)); return Ok((inputs, Some(format), false));
} }
// if tools are set, apply the tool grammar and then the chat template let (updated_tools, tool_schema) = ToolGrammar::apply(tools.unwrap().clone(), tool_choice)?;
let tool_grammar: Option<Tools> = ToolGrammar::apply(tools.clone(), tool_choice)?;
let grammar = tool_grammar let grammar = tool_schema
.as_ref() .as_ref()
.map(|t| GrammarType::Json(serde_json::json!(t))); .map(|t| GrammarType::Json(serde_json::json!(t)));
let tools_and_prompt: (Option<Vec<Tool>>, String) = (tools, tool_prompt.into());
let inputs = infer.apply_chat_template(guideline, messages, Some(tools_and_prompt))?; let inputs: String = infer.apply_chat_template(
Ok((inputs, grammar, tool_grammar.is_some())) guideline,
messages,
Some((updated_tools, tool_prompt.into())),
)?;
Ok((inputs, grammar, tool_schema.is_some()))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::collections::HashMap;
use super::*; use super::*;
use crate::ChatTemplateVersions; use crate::ChatTemplateVersions;
use crate::FunctionsMap;
use crate::HubTokenizerConfig; use crate::HubTokenizerConfig;
use crate::Properties;
use crate::TokenizerConfigToken; use crate::TokenizerConfigToken;
use crate::Tool; use crate::Tool;
use crate::Tools;
use serde_json::json; use serde_json::json;
@ -2595,26 +2594,26 @@ mod tests {
impl Backend for MockBackend { impl Backend for MockBackend {
fn schedule( fn schedule(
&self, &self,
request: crate::validation::ValidGenerateRequest, _request: crate::validation::ValidGenerateRequest,
) -> Result< ) -> Result<
tokio_stream::wrappers::UnboundedReceiverStream< tokio_stream::wrappers::UnboundedReceiverStream<
Result<InferStreamResponse, InferError>, Result<InferStreamResponse, InferError>,
>, >,
InferError, InferError,
> { > {
unimplemented!() unimplemented!("Never called in this test");
} }
fn health<'life0, 'async_trait>( fn health<'a, 'async_trait>(
&'life0 self, &'a self,
current_health: bool, _current_health: bool,
) -> core::pin::Pin< ) -> core::pin::Pin<
Box<dyn core::future::Future<Output = bool> + core::marker::Send + 'async_trait>, Box<dyn core::future::Future<Output = bool> + core::marker::Send + 'async_trait>,
> >
where where
'life0: 'async_trait, 'a: 'async_trait,
Self: 'async_trait, Self: 'async_trait,
{ {
unimplemented!() unimplemented!("Never called in this test");
} }
} }
@ -2680,8 +2679,8 @@ mod tests {
); );
assert!(result.is_ok()); assert!(result.is_ok());
let (inputs, grammar, using_tools) = result.unwrap(); let (inputs, _grammar, using_tools) = result.unwrap();
assert_eq!(using_tools, true); assert_eq!(using_tools, true);
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"_name\":{\"const\":\"notify_error\",\"type\":\"string\"},\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\",\"_name\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
} }
} }