diff --git a/docs/openapi.json b/docs/openapi.json index df21e19d..fd64a3ab 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -924,7 +924,7 @@ "tool_prompt": { "type": "string", "description": "A prompt to be appended before the tools", - "example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"", + "example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.", "nullable": true }, "tools": { diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index 4a8141b4..bfa9421c 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -56,31 +56,33 @@ impl ChatTemplate { &self, guideline: Option<&str>, mut messages: Vec, - tools_and_prompt: Option<(Option>, String)>, + tools_and_prompt: Option<(Vec, String)>, ) -> Result { // check if guideline is expected but not provided if self.variables.contains("guideline") && guideline.is_none() { return Err(InferError::MissingTemplateVariable("guideline".to_string())); } - let (tools, tool_prompt) = tools_and_prompt.unwrap_or_default(); - - if let Some(ref tools) = tools { - // check if the `tools` variable is used in the template - // if not, we need to append the tools to the last message - let text = if self.use_default_tool_template { - match serde_json::to_string(tools) { - Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt), - Err(e) => return Err(InferError::ToolError(e.to_string())), + let tools = match tools_and_prompt { + Some((tools, tool_prompt)) => { + // check if the `tools` variable is used in the template + // if not, we need to append the tools to the last message + let text = if self.use_default_tool_template { + match serde_json::to_string(&tools) { + Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt), + Err(e) => return Err(InferError::ToolError(e.to_string())), + } + } else { + // if the `tools` variable is used in the template, we just append the tool_prompt + format!("\n---\n{}", tool_prompt) + }; + if let Some(last_message) = messages.last_mut() { + last_message.content.push(MessageChunk::Text { text }); } - } else { - // if the `tools` variable is used in the template, we just append the tool_prompt - format!("\n---\n{}", tool_prompt) - }; - if let Some(last_message) = messages.last_mut() { - last_message.content.push(MessageChunk::Text { text }); + Some(tools) } - } + None => None, + }; let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); @@ -92,7 +94,6 @@ impl ChatTemplate { eos_token: self.eos_token.as_deref(), add_generation_prompt: true, tools, - tools_prompt: None, }) .map_err(InferError::TemplateError) } @@ -104,8 +105,7 @@ mod tests { use crate::infer::chat_template::raise_exception; use crate::infer::ChatTemplate; use crate::{ - ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage, - TokenizerConfigToken, Tool, + ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool, }; use minijinja::Environment; @@ -867,7 +867,7 @@ mod tests { let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); let tools: Vec = serde_json::from_str(&tools_string).unwrap(); let tool_prompt = "This default prompt will be used".to_string(); - let tools_and_prompt = Some((Some(tools), tool_prompt)); + let tools_and_prompt = Some((tools, tool_prompt)); let result = ct.apply(None, msgs, tools_and_prompt); let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string(); assert_eq!(result.unwrap(), expected); diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 5329dc72..81c0d38f 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -140,7 +140,7 @@ impl Infer { &self, guideline: Option, messages: Vec, - tools_and_prompt: Option<(Option>, String)>, + tools_and_prompt: Option<(Vec, String)>, ) -> Result { self.chat_template .as_ref() diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index 05027f30..64d64d9b 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -1,5 +1,8 @@ use crate::infer::InferError; -use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools}; +use crate::{ + FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice, + ToolType, +}; use serde_json::{json, Map, Value}; use std::collections::HashMap; @@ -16,17 +19,42 @@ impl ToolGrammar { } pub fn apply( - tools: Option>, + tools: Vec, tool_choice: ToolChoice, - ) -> Result, InferError> { + ) -> Result<(Vec, Option), InferError> { // if no tools are provided, we return None - let tools = match tools { - Some(tools) if !tools.is_empty() => tools, - _ => return Ok(None), - }; + if tools.is_empty() { + return Ok((tools, None)); + } let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); + let mut tools = tools.clone(); + + // add the notify_error function to the tools + let notify_error = Tool { + r#type: "function".to_string(), + function: FunctionDefinition { + name: "notify_error".to_string(), + description: Some("Notify an error or issue".to_string()), + arguments: json!({ + "type": "object", + "properties": { + "error": { + "type": "string", + "description": "The error or issue to notify" + }, + "_name": { + "type": "string", + "const": "notify_error" + } + }, + "required": ["error", "_name"] + }), + }, + }; + tools.push(notify_error); + // if tools are provided and no tool_choice we default to the OneOf let tools_to_use = match tool_choice { ToolType::FunctionName(name) => { @@ -35,27 +63,10 @@ impl ToolGrammar { ToolType::Function { function } => { vec![Self::find_tool_by_name(&tools, &function.name)?] } - ToolType::OneOf => tools, - ToolType::NoTool => return Ok(None), + ToolType::OneOf => tools.clone(), + ToolType::NoTool => return Ok((tools, None)), }; - // adds the error notification function for LLM feedback if required - let mut text_response_properties = Map::new(); - text_response_properties.insert( - "error".to_string(), - serde_json::json!({ - "type": "string", - "description": "The error or issue to notify" - }), - ); - text_response_properties.insert( - "_name".to_string(), - serde_json::json!({ - "type": "string", - "const": "notify_error" - }), - ); - let functions: HashMap = tools_to_use .iter() .map(|tool| { @@ -105,17 +116,9 @@ impl ToolGrammar { (func.name, Value::Object(params)) }) - .chain([( - "notify_error".to_string(), - serde_json::json!({ - "properties": text_response_properties, - "required": ["error", "_name"], - "type": "object" - }), - )]) .collect(); - let tools = Tools { + let tool_schema = JsonSchemaTool { functions_map: FunctionsMap { functions }, properties: Properties { function: tools_to_use @@ -130,6 +133,6 @@ impl ToolGrammar { }, }; - Ok(Some(tools)) + Ok((tools, Some(tool_schema))) } } diff --git a/router/src/lib.rs b/router/src/lib.rs index fbf23631..af77d436 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -910,7 +910,7 @@ impl From for ToolChoice { } #[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)] -pub struct Tools { +pub struct JsonSchemaTool { #[serde(flatten)] functions_map: FunctionsMap, properties: Properties, @@ -969,7 +969,6 @@ pub(crate) struct ChatTemplateInputs<'a> { eos_token: Option<&'a str>, add_generation_prompt: bool, tools: Option>, - tools_prompt: Option<&'a str>, guideline: Option<&'a str>, } @@ -1207,7 +1206,6 @@ pub(crate) struct GenerateResponse { pub(crate) struct ChatTokenizeResponse { pub(crate) tokenize_response: TokenizeResponse, pub(crate) templated_text: String, - pub(crate) using_tools: bool, } #[derive(Serialize, ToSchema)] diff --git a/router/src/server.rs b/router/src/server.rs index 791165eb..5ccaa6f8 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -23,7 +23,7 @@ use crate::{ CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; @@ -146,7 +146,7 @@ async fn get_chat_tokenize( } = req; let tool_prompt = tool_prompt.unwrap_or_default(); - let (inputs, _grammar, using_tools) = prepare_chat_input( + let (inputs, _grammar, _using_tools) = prepare_chat_input( &infer, response_format, tools, @@ -206,7 +206,6 @@ async fn get_chat_tokenize( let resp = ChatTokenizeResponse { tokenize_response: TokenizeResponse(tokens), templated_text: input, - using_tools, }; Ok((HeaderMap::new(), Json(resp))) } else { @@ -2562,28 +2561,28 @@ fn prepare_chat_input( return Ok((inputs, Some(format), false)); } - // if tools are set, apply the tool grammar and then the chat template - let tool_grammar: Option = ToolGrammar::apply(tools.clone(), tool_choice)?; - let grammar = tool_grammar + let (updated_tools, tool_schema) = ToolGrammar::apply(tools.unwrap().clone(), tool_choice)?; + + let grammar = tool_schema .as_ref() .map(|t| GrammarType::Json(serde_json::json!(t))); - let tools_and_prompt: (Option>, String) = (tools, tool_prompt.into()); - let inputs = infer.apply_chat_template(guideline, messages, Some(tools_and_prompt))?; - Ok((inputs, grammar, tool_grammar.is_some())) + + let inputs: String = infer.apply_chat_template( + guideline, + messages, + Some((updated_tools, tool_prompt.into())), + )?; + + Ok((inputs, grammar, tool_schema.is_some())) } #[cfg(test)] mod tests { - use std::collections::HashMap; - use super::*; use crate::ChatTemplateVersions; - use crate::FunctionsMap; use crate::HubTokenizerConfig; - use crate::Properties; use crate::TokenizerConfigToken; use crate::Tool; - use crate::Tools; use serde_json::json; @@ -2595,26 +2594,26 @@ mod tests { impl Backend for MockBackend { fn schedule( &self, - request: crate::validation::ValidGenerateRequest, + _request: crate::validation::ValidGenerateRequest, ) -> Result< tokio_stream::wrappers::UnboundedReceiverStream< Result, >, InferError, > { - unimplemented!() + unimplemented!("Never called in this test"); } - fn health<'life0, 'async_trait>( - &'life0 self, - current_health: bool, + fn health<'a, 'async_trait>( + &'a self, + _current_health: bool, ) -> core::pin::Pin< Box + core::marker::Send + 'async_trait>, > where - 'life0: 'async_trait, + 'a: 'async_trait, Self: 'async_trait, { - unimplemented!() + unimplemented!("Never called in this test"); } } @@ -2680,8 +2679,8 @@ mod tests { ); assert!(result.is_ok()); - let (inputs, grammar, using_tools) = result.unwrap(); + let (inputs, _grammar, using_tools) = result.unwrap(); assert_eq!(using_tools, true); - assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); + assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"_name\":{\"const\":\"notify_error\",\"type\":\"string\"},\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\",\"_name\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); } }