From f53c8059e9ea14059279334ba87a51221bfa3fff Mon Sep 17 00:00:00 2001 From: David Holtz Date: Mon, 14 Oct 2024 17:40:45 +0000 Subject: [PATCH] feat: improve, simplify and rename tool choice struct add required support and refactor --- docs/openapi.json | 78 +++++++++++---------- router/src/infer/tool_grammar.rs | 57 ++++++++-------- router/src/lib.rs | 113 ++++++++++++++++++++----------- router/src/server.rs | 9 ++- 4 files changed, 144 insertions(+), 113 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 903f7426..06c1f144 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -921,6 +921,42 @@ } } }, + "ChatCompletionToolChoiceOption": { + "oneOf": [ + { + "type": "string", + "description": "Means the model can pick between generating a message or calling one or more tools.", + "enum": [ + "auto" + ] + }, + { + "type": "string", + "description": "Means the model will not call any tool and instead generates a message.", + "enum": [ + "none" + ] + }, + { + "type": "string", + "description": "Means the model must call one or more tools.", + "enum": [ + "required" + ] + }, + { + "type": "object", + "required": [ + "function" + ], + "properties": { + "function": { + "$ref": "#/components/schemas/FunctionName" + } + } + } + ] + }, "ChatCompletionTopLogprob": { "type": "object", "required": [ @@ -1055,9 +1091,10 @@ "tool_choice": { "allOf": [ { - "$ref": "#/components/schemas/ToolChoice" + "$ref": "#/components/schemas/ChatCompletionToolChoiceOption" } ], + "default": "null", "nullable": true }, "tool_prompt": { @@ -2234,45 +2271,6 @@ } } }, - "ToolChoice": { - "allOf": [ - { - "$ref": "#/components/schemas/ToolType" - } - ], - "nullable": true - }, - "ToolType": { - "oneOf": [ - { - "type": "string", - "description": "Means the model can pick between generating a message or calling one or more tools.", - "enum": [ - "auto" - ] - }, - { - "type": "string", - "description": "Means the model will not call any tool and instead generates a message.", 
- "enum": [ - "none" - ] - }, - { - "type": "object", - "required": [ - "function" - ], - "properties": { - "function": { - "$ref": "#/components/schemas/FunctionName" - } - } - } - ], - "description": "Controls which (if any) tool is called by the model.", - "example": "auto" - }, "Url": { "type": "object", "required": [ diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index f86205fb..b9070812 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -1,7 +1,7 @@ use crate::infer::InferError; use crate::{ - FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice, - ToolType, + ChatCompletionToolChoiceOption, FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, + Properties, Tool, }; use serde_json::{json, Map, Value}; use std::collections::HashMap; @@ -20,44 +20,47 @@ impl ToolGrammar { pub fn apply( tools: Vec, - tool_choice: ToolChoice, + tool_choice: ChatCompletionToolChoiceOption, ) -> Result<(Vec, Option), InferError> { // if no tools are provided, we return None if tools.is_empty() { return Ok((tools, None)); } - let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); - let mut tools = tools.clone(); - // add the no_tool function to the tools - let no_tool = Tool { - r#type: "function".to_string(), - function: FunctionDefinition { - name: "no_tool".to_string(), - description: Some("Open ened response with no specific tool selected".to_string()), - arguments: json!({ - "type": "object", - "properties": { - "content": { - "type": "string", - "description": "The response content", - } - }, - "required": ["content"] - }), - }, - }; - tools.push(no_tool); + // add the no_tool function to the tools unless tool_choice is `required` (the model must then call one of the provided tools) + if tool_choice != ChatCompletionToolChoiceOption::Required { + let no_tool = Tool { + r#type: "function".to_string(), + function: FunctionDefinition { + name: "no_tool".to_string(), + description: Some( 
"Open ended response with no specific tool selected".to_string(), + ), + arguments: json!({ + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "The response content", + } + }, + "required": ["content"] + }), + }, + }; + tools.push(no_tool); + } // if tools are provided and no tool_choice we default to the OneOf let tools_to_use = match tool_choice { - ToolType::Function(function) => { + ChatCompletionToolChoiceOption::Function(function) => { vec![Self::find_tool_by_name(&tools, &function.name)?] } - ToolType::OneOf => tools.clone(), - ToolType::NoTool => return Ok((tools, None)), + ChatCompletionToolChoiceOption::Required => tools.clone(), + ChatCompletionToolChoiceOption::Auto => tools.clone(), + ChatCompletionToolChoiceOption::NoTool => return Ok((tools, None)), }; let functions: HashMap = tools_to_use diff --git a/router/src/lib.rs b/router/src/lib.rs index 6ecc6b39..15d202a8 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -892,8 +892,8 @@ pub(crate) struct ChatRequest { /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. #[serde(default)] - #[schema(nullable = true, example = "null")] - pub tool_choice: ToolChoice, + #[schema(nullable = true, default = "null", example = "null")] + pub tool_choice: Option, /// Response format constraints for the generation. 
/// @@ -949,8 +949,18 @@ impl ChatRequest { let (inputs, grammar, using_tools) = prepare_chat_input( infer, response_format, - tools, - tool_choice, + tools.clone(), + // unwrap or default (use "auto" if tools are present, and "none" if not) + tool_choice.map_or_else( + || { + if tools.is_some() { + ChatCompletionToolChoiceOption::Auto + } else { + ChatCompletionToolChoiceOption::NoTool + } + }, + |t| t, + ), &tool_prompt, guideline, messages, @@ -999,22 +1009,6 @@ pub fn default_tool_prompt() -> String { "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string() } -#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] -#[schema(example = "auto")] -/// Controls which (if any) tool is called by the model. -pub enum ToolType { - /// Means the model can pick between generating a message or calling one or more tools. - #[schema(rename = "auto")] - OneOf, - /// Means the model will not call any tool and instead generates a message. - #[schema(rename = "none")] - NoTool, - /// Forces the model to call a specific tool. 
- #[schema(rename = "function")] - #[serde(alias = "function")] - Function(FunctionName), -} - #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] #[serde(tag = "type")] pub enum TypedChoice { @@ -1027,29 +1021,59 @@ pub struct FunctionName { pub name: String, } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)] #[serde(from = "ToolTypeDeserializer")] -pub struct ToolChoice(pub Option); - -#[derive(Deserialize)] -#[serde(untagged)] -enum ToolTypeDeserializer { - Null, - String(String), - ToolType(TypedChoice), +pub enum ChatCompletionToolChoiceOption { + /// Means the model can pick between generating a message or calling one or more tools. + #[schema(rename = "auto")] + Auto, + /// Means the model will not call any tool and instead generates a message. + #[schema(rename = "none")] + #[default] + NoTool, + /// Means the model must call one or more tools. + #[schema(rename = "required")] + Required, + /// Forces the model to call a specific tool. + #[schema(rename = "function")] + #[serde(alias = "function")] + Function(FunctionName), } -impl From for ToolChoice { +#[derive(Deserialize, ToSchema)] +#[serde(untagged)] +/// Controls which (if any) tool is called by the model. +/// - `none` means the model will not call any tool and instead generates a message. +/// - `auto` means the model can pick between generating a message or calling one or more tools. +/// - `required` means the model must call one or more tools. +/// - Specifying a particular tool via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool. +/// +/// `none` is the default when no tools are present. `auto` is the default if tools are present. +enum ToolTypeDeserializer { + /// `none` means the model will not call any tool and instead generates a message. 
+ Null, + + /// `auto` means the model can pick between generating a message or calling one or more tools. + #[schema(example = "auto")] + String(String), + + /// Specifying a particular tool forces the model to call that tool, with structured function details. + #[schema(example = r#"{"type": "function", "function": {"name": "my_function"}}"#)] + TypedChoice(TypedChoice), +} + +impl From for ChatCompletionToolChoiceOption { fn from(value: ToolTypeDeserializer) -> Self { match value { - ToolTypeDeserializer::Null => ToolChoice(None), + ToolTypeDeserializer::Null => ChatCompletionToolChoiceOption::NoTool, ToolTypeDeserializer::String(s) => match s.as_str() { - "none" => ToolChoice(Some(ToolType::NoTool)), - "auto" => ToolChoice(Some(ToolType::OneOf)), - _ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))), + "none" => ChatCompletionToolChoiceOption::NoTool, + "auto" => ChatCompletionToolChoiceOption::Auto, + "required" => ChatCompletionToolChoiceOption::Required, + _ => ChatCompletionToolChoiceOption::Function(FunctionName { name: s }), }, - ToolTypeDeserializer::ToolType(TypedChoice::Function { function }) => { - ToolChoice(Some(ToolType::Function(function))) + ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => { + ChatCompletionToolChoiceOption::Function(function) } } } @@ -1661,20 +1685,27 @@ mod tests { fn tool_choice_formats() { #[derive(Deserialize)] struct TestRequest { - tool_choice: ToolChoice, + tool_choice: ChatCompletionToolChoiceOption, } let none = r#"{"tool_choice":"none"}"#; let de_none: TestRequest = serde_json::from_str(none).unwrap(); - assert_eq!(de_none.tool_choice, ToolChoice(Some(ToolType::NoTool))); + assert_eq!(de_none.tool_choice, ChatCompletionToolChoiceOption::NoTool); let auto = r#"{"tool_choice":"auto"}"#; let de_auto: TestRequest = serde_json::from_str(auto).unwrap(); - assert_eq!(de_auto.tool_choice, ToolChoice(Some(ToolType::OneOf))); + assert_eq!(de_auto.tool_choice, 
ChatCompletionToolChoiceOption::Auto); - let ref_choice = ToolChoice(Some(ToolType::Function(FunctionName { + let auto = r#"{"tool_choice":"required"}"#; + let de_auto: TestRequest = serde_json::from_str(auto).unwrap(); + assert_eq!( + de_auto.tool_choice, + ChatCompletionToolChoiceOption::Required + ); + + let ref_choice = ChatCompletionToolChoiceOption::Function(FunctionName { name: "myfn".to_string(), - }))); + }); let named = r#"{"tool_choice":"myfn"}"#; let de_named: TestRequest = serde_json::from_str(named).unwrap(); diff --git a/router/src/server.rs b/router/src/server.rs index 863607b1..26a43f0a 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -28,7 +28,7 @@ use crate::{ ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; +use crate::{ChatCompletionToolChoiceOption, FunctionDefinition, HubPreprocessorConfig, ToolCall}; use crate::{ModelInfo, ModelsInfo}; use async_stream::__private::AsyncStream; use axum::extract::Extension; @@ -1551,12 +1551,11 @@ GrammarType, Usage, StreamOptions, DeltaToolCall, -ToolType, Tool, ToolCall, Function, FunctionDefinition, -ToolChoice, +ChatCompletionToolChoiceOption, ModelInfo, ) ), @@ -2522,7 +2521,7 @@ pub(crate) fn prepare_chat_input( infer: &Infer, response_format: Option, tools: Option>, - tool_choice: ToolChoice, + tool_choice: ChatCompletionToolChoiceOption, tool_prompt: &str, guideline: Option, messages: Vec, @@ -2660,7 +2659,7 @@ mod tests { &infer, response_format, tools, - ToolChoice(None), + ChatCompletionToolChoiceOption::Auto, tool_prompt, guideline, messages,