feat: refactor tool logic to include notify_error in prompt and adjust typing

This commit is contained in:
drbh 2024-08-23 21:07:43 +00:00
parent 9ea34977ac
commit 1bf0e3b65c
6 changed files with 85 additions and 85 deletions

View File

@ -924,7 +924,7 @@
"tool_prompt": {
"type": "string",
"description": "A prompt to be appended before the tools",
"example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"",
"example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.",
"nullable": true
},
"tools": {

View File

@ -56,31 +56,33 @@ impl ChatTemplate {
&self,
guideline: Option<&str>,
mut messages: Vec<Message>,
tools_and_prompt: Option<(Option<Vec<Tool>>, String)>,
tools_and_prompt: Option<(Vec<Tool>, String)>,
) -> Result<String, InferError> {
// check if guideline is expected but not provided
if self.variables.contains("guideline") && guideline.is_none() {
return Err(InferError::MissingTemplateVariable("guideline".to_string()));
}
let (tools, tool_prompt) = tools_and_prompt.unwrap_or_default();
if let Some(ref tools) = tools {
// check if the `tools` variable is used in the template
// if not, we need to append the tools to the last message
let text = if self.use_default_tool_template {
match serde_json::to_string(tools) {
Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
Err(e) => return Err(InferError::ToolError(e.to_string())),
let tools = match tools_and_prompt {
Some((tools, tool_prompt)) => {
// check if the `tools` variable is used in the template
// if not, we need to append the tools to the last message
let text = if self.use_default_tool_template {
match serde_json::to_string(&tools) {
Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
Err(e) => return Err(InferError::ToolError(e.to_string())),
}
} else {
// if the `tools` variable is used in the template, we just append the tool_prompt
format!("\n---\n{}", tool_prompt)
};
if let Some(last_message) = messages.last_mut() {
last_message.content.push(MessageChunk::Text { text });
}
} else {
// if the `tools` variable is used in the template, we just append the tool_prompt
format!("\n---\n{}", tool_prompt)
};
if let Some(last_message) = messages.last_mut() {
last_message.content.push(MessageChunk::Text { text });
Some(tools)
}
}
None => None,
};
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
@ -92,7 +94,6 @@ impl ChatTemplate {
eos_token: self.eos_token.as_deref(),
add_generation_prompt: true,
tools,
tools_prompt: None,
})
.map_err(InferError::TemplateError)
}
@ -104,8 +105,7 @@ mod tests {
use crate::infer::chat_template::raise_exception;
use crate::infer::ChatTemplate;
use crate::{
ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage,
TokenizerConfigToken, Tool,
ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool,
};
use minijinja::Environment;
@ -867,7 +867,7 @@ mod tests {
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((Some(tools), tool_prompt));
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(None, msgs, tools_and_prompt);
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
assert_eq!(result.unwrap(), expected);

View File

@ -140,7 +140,7 @@ impl Infer {
&self,
guideline: Option<String>,
messages: Vec<Message>,
tools_and_prompt: Option<(Option<Vec<Tool>>, String)>,
tools_and_prompt: Option<(Vec<Tool>, String)>,
) -> Result<String, InferError> {
self.chat_template
.as_ref()

View File

@ -1,5 +1,8 @@
use crate::infer::InferError;
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools};
use crate::{
FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
ToolType,
};
use serde_json::{json, Map, Value};
use std::collections::HashMap;
@ -16,17 +19,42 @@ impl ToolGrammar {
}
pub fn apply(
tools: Option<Vec<Tool>>,
tools: Vec<Tool>,
tool_choice: ToolChoice,
) -> Result<Option<Tools>, InferError> {
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
// if no tools are provided, we return None
let tools = match tools {
Some(tools) if !tools.is_empty() => tools,
_ => return Ok(None),
};
if tools.is_empty() {
return Ok((tools, None));
}
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
let mut tools = tools.clone();
// add the notify_error function to the tools
let notify_error = Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
name: "notify_error".to_string(),
description: Some("Notify an error or issue".to_string()),
arguments: json!({
"type": "object",
"properties": {
"error": {
"type": "string",
"description": "The error or issue to notify"
},
"_name": {
"type": "string",
"const": "notify_error"
}
},
"required": ["error", "_name"]
}),
},
};
tools.push(notify_error);
// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
@ -35,27 +63,10 @@ impl ToolGrammar {
ToolType::Function { function } => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ToolType::OneOf => tools,
ToolType::NoTool => return Ok(None),
ToolType::OneOf => tools.clone(),
ToolType::NoTool => return Ok((tools, None)),
};
// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);
let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter()
.map(|tool| {
@ -105,17 +116,9 @@ impl ToolGrammar {
(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect();
let tools = Tools {
let tool_schema = JsonSchemaTool {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
@ -130,6 +133,6 @@ impl ToolGrammar {
},
};
Ok(Some(tools))
Ok((tools, Some(tool_schema)))
}
}

View File

@ -910,7 +910,7 @@ impl From<ToolTypeDeserializer> for ToolChoice {
}
#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
pub struct Tools {
pub struct JsonSchemaTool {
#[serde(flatten)]
functions_map: FunctionsMap,
properties: Properties,
@ -969,7 +969,6 @@ pub(crate) struct ChatTemplateInputs<'a> {
eos_token: Option<&'a str>,
add_generation_prompt: bool,
tools: Option<Vec<Tool>>,
tools_prompt: Option<&'a str>,
guideline: Option<&'a str>,
}
@ -1207,7 +1206,6 @@ pub(crate) struct GenerateResponse {
pub(crate) struct ChatTokenizeResponse {
pub(crate) tokenize_response: TokenizeResponse,
pub(crate) templated_text: String,
pub(crate) using_tools: bool,
}
#[derive(Serialize, ToSchema)]

View File

@ -23,7 +23,7 @@ use crate::{
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
VertexResponse,
};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
use async_stream::__private::AsyncStream;
use axum::extract::Extension;
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
@ -146,7 +146,7 @@ async fn get_chat_tokenize(
} = req;
let tool_prompt = tool_prompt.unwrap_or_default();
let (inputs, _grammar, using_tools) = prepare_chat_input(
let (inputs, _grammar, _using_tools) = prepare_chat_input(
&infer,
response_format,
tools,
@ -206,7 +206,6 @@ async fn get_chat_tokenize(
let resp = ChatTokenizeResponse {
tokenize_response: TokenizeResponse(tokens),
templated_text: input,
using_tools,
};
Ok((HeaderMap::new(), Json(resp)))
} else {
@ -2562,28 +2561,28 @@ fn prepare_chat_input(
return Ok((inputs, Some(format), false));
}
// if tools are set, apply the tool grammar and then the chat template
let tool_grammar: Option<Tools> = ToolGrammar::apply(tools.clone(), tool_choice)?;
let grammar = tool_grammar
let (updated_tools, tool_schema) = ToolGrammar::apply(tools.unwrap().clone(), tool_choice)?;
let grammar = tool_schema
.as_ref()
.map(|t| GrammarType::Json(serde_json::json!(t)));
let tools_and_prompt: (Option<Vec<Tool>>, String) = (tools, tool_prompt.into());
let inputs = infer.apply_chat_template(guideline, messages, Some(tools_and_prompt))?;
Ok((inputs, grammar, tool_grammar.is_some()))
let inputs: String = infer.apply_chat_template(
guideline,
messages,
Some((updated_tools, tool_prompt.into())),
)?;
Ok((inputs, grammar, tool_schema.is_some()))
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use super::*;
use crate::ChatTemplateVersions;
use crate::FunctionsMap;
use crate::HubTokenizerConfig;
use crate::Properties;
use crate::TokenizerConfigToken;
use crate::Tool;
use crate::Tools;
use serde_json::json;
@ -2595,26 +2594,26 @@ mod tests {
impl Backend for MockBackend {
fn schedule(
&self,
request: crate::validation::ValidGenerateRequest,
_request: crate::validation::ValidGenerateRequest,
) -> Result<
tokio_stream::wrappers::UnboundedReceiverStream<
Result<InferStreamResponse, InferError>,
>,
InferError,
> {
unimplemented!()
unimplemented!("Never called in this test");
}
fn health<'life0, 'async_trait>(
&'life0 self,
current_health: bool,
fn health<'a, 'async_trait>(
&'a self,
_current_health: bool,
) -> core::pin::Pin<
Box<dyn core::future::Future<Output = bool> + core::marker::Send + 'async_trait>,
>
where
'life0: 'async_trait,
'a: 'async_trait,
Self: 'async_trait,
{
unimplemented!()
unimplemented!("Never called in this test");
}
}
@ -2680,8 +2679,8 @@ mod tests {
);
assert!(result.is_ok());
let (inputs, grammar, using_tools) = result.unwrap();
let (inputs, _grammar, using_tools) = result.unwrap();
assert_eq!(using_tools, true);
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"_name\":{\"const\":\"notify_error\",\"type\":\"string\"},\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\",\"_name\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
}
}