feat: improve, simplify and rename the tool choice struct; add `required` support and refactor

This commit is contained in:
David Holtz 2024-10-14 17:40:45 +00:00 committed by drbh
parent 209f841767
commit f53c8059e9
4 changed files with 144 additions and 113 deletions

View File

@ -921,6 +921,42 @@
} }
} }
}, },
"ChatCompletionToolChoiceOption": {
"oneOf": [
{
"type": "string",
"description": "Means the model can pick between generating a message or calling one or more tools.",
"enum": [
"auto"
]
},
{
"type": "string",
"description": "Means the model will not call any tool and instead generates a message.",
"enum": [
"none"
]
},
{
"type": "string",
"description": "Means the model must call one or more tools.",
"enum": [
"required"
]
},
{
"type": "object",
"required": [
"function"
],
"properties": {
"function": {
"$ref": "#/components/schemas/FunctionName"
}
}
}
]
},
"ChatCompletionTopLogprob": { "ChatCompletionTopLogprob": {
"type": "object", "type": "object",
"required": [ "required": [
@ -1055,9 +1091,10 @@
"tool_choice": { "tool_choice": {
"allOf": [ "allOf": [
{ {
"$ref": "#/components/schemas/ToolChoice" "$ref": "#/components/schemas/ChatCompletionToolChoiceOption"
} }
], ],
"default": "null",
"nullable": true "nullable": true
}, },
"tool_prompt": { "tool_prompt": {
@ -2234,45 +2271,6 @@
} }
} }
}, },
"ToolChoice": {
"allOf": [
{
"$ref": "#/components/schemas/ToolType"
}
],
"nullable": true
},
"ToolType": {
"oneOf": [
{
"type": "string",
"description": "Means the model can pick between generating a message or calling one or more tools.",
"enum": [
"auto"
]
},
{
"type": "string",
"description": "Means the model will not call any tool and instead generates a message.",
"enum": [
"none"
]
},
{
"type": "object",
"required": [
"function"
],
"properties": {
"function": {
"$ref": "#/components/schemas/FunctionName"
}
}
}
],
"description": "Controls which (if any) tool is called by the model.",
"example": "auto"
},
"Url": { "Url": {
"type": "object", "type": "object",
"required": [ "required": [

View File

@ -1,7 +1,7 @@
use crate::infer::InferError; use crate::infer::InferError;
use crate::{ use crate::{
FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice, ChatCompletionToolChoiceOption, FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool,
ToolType, Properties, Tool,
}; };
use serde_json::{json, Map, Value}; use serde_json::{json, Map, Value};
use std::collections::HashMap; use std::collections::HashMap;
@ -20,44 +20,47 @@ impl ToolGrammar {
pub fn apply( pub fn apply(
tools: Vec<Tool>, tools: Vec<Tool>,
tool_choice: ToolChoice, tool_choice: ChatCompletionToolChoiceOption,
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> { ) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
// if no tools are provided, we return None // if no tools are provided, we return None
if tools.is_empty() { if tools.is_empty() {
return Ok((tools, None)); return Ok((tools, None));
} }
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
let mut tools = tools.clone(); let mut tools = tools.clone();
// add the no_tool function to the tools // add the no_tool function to the tools unless tool_choice is `required`, in which case the model must call one of the provided tools
let no_tool = Tool { if tool_choice != ChatCompletionToolChoiceOption::Required {
r#type: "function".to_string(), let no_tool = Tool {
function: FunctionDefinition { r#type: "function".to_string(),
name: "no_tool".to_string(), function: FunctionDefinition {
description: Some("Open ened response with no specific tool selected".to_string()), name: "no_tool".to_string(),
arguments: json!({ description: Some(
"type": "object", "Open ended response with no specific tool selected".to_string(),
"properties": { ),
"content": { arguments: json!({
"type": "string", "type": "object",
"description": "The response content", "properties": {
} "content": {
}, "type": "string",
"required": ["content"] "description": "The response content",
}), }
}, },
}; "required": ["content"]
tools.push(no_tool); }),
},
};
tools.push(no_tool);
}
// if tools are provided and no tool_choice we default to the OneOf // if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice { let tools_to_use = match tool_choice {
ToolType::Function(function) => { ChatCompletionToolChoiceOption::Function(function) => {
vec![Self::find_tool_by_name(&tools, &function.name)?] vec![Self::find_tool_by_name(&tools, &function.name)?]
} }
ToolType::OneOf => tools.clone(), ChatCompletionToolChoiceOption::Required => tools.clone(),
ToolType::NoTool => return Ok((tools, None)), ChatCompletionToolChoiceOption::Auto => tools.clone(),
ChatCompletionToolChoiceOption::NoTool => return Ok((tools, None)),
}; };
let functions: HashMap<String, serde_json::Value> = tools_to_use let functions: HashMap<String, serde_json::Value> = tools_to_use

View File

@ -892,8 +892,8 @@ pub(crate) struct ChatRequest {
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
#[serde(default)] #[serde(default)]
#[schema(nullable = true, example = "null")] #[schema(nullable = true, default = "null", example = "null")]
pub tool_choice: ToolChoice, pub tool_choice: Option<ChatCompletionToolChoiceOption>,
/// Response format constraints for the generation. /// Response format constraints for the generation.
/// ///
@ -949,8 +949,18 @@ impl ChatRequest {
let (inputs, grammar, using_tools) = prepare_chat_input( let (inputs, grammar, using_tools) = prepare_chat_input(
infer, infer,
response_format, response_format,
tools, tools.clone(),
tool_choice, // unwrap or default (use "auto" if tools are present, and "none" if not)
tool_choice.map_or_else(
|| {
if tools.is_some() {
ChatCompletionToolChoiceOption::Auto
} else {
ChatCompletionToolChoiceOption::NoTool
}
},
|t| t,
),
&tool_prompt, &tool_prompt,
guideline, guideline,
messages, messages,
@ -999,22 +1009,6 @@ pub fn default_tool_prompt() -> String {
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string() "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
} }
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
#[schema(example = "auto")]
/// Controls which (if any) tool is called by the model.
pub enum ToolType {
/// Means the model can pick between generating a message or calling one or more tools.
#[schema(rename = "auto")]
OneOf,
/// Means the model will not call any tool and instead generates a message.
#[schema(rename = "none")]
NoTool,
/// Forces the model to call a specific tool.
#[schema(rename = "function")]
#[serde(alias = "function")]
Function(FunctionName),
}
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum TypedChoice { pub enum TypedChoice {
@ -1027,29 +1021,59 @@ pub struct FunctionName {
pub name: String, pub name: String,
} }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)]
#[serde(from = "ToolTypeDeserializer")] #[serde(from = "ToolTypeDeserializer")]
pub struct ToolChoice(pub Option<ToolType>); pub enum ChatCompletionToolChoiceOption {
/// Means the model can pick between generating a message or calling one or more tools.
#[derive(Deserialize)] #[schema(rename = "auto")]
#[serde(untagged)] Auto,
enum ToolTypeDeserializer { /// Means the model will not call any tool and instead generates a message.
Null, #[schema(rename = "none")]
String(String), #[default]
ToolType(TypedChoice), NoTool,
/// Means the model must call one or more tools.
#[schema(rename = "required")]
Required,
/// Forces the model to call a specific tool.
#[schema(rename = "function")]
#[serde(alias = "function")]
Function(FunctionName),
} }
impl From<ToolTypeDeserializer> for ToolChoice { #[derive(Deserialize, ToSchema)]
#[serde(untagged)]
/// Controls which (if any) tool is called by the model.
/// - `none` means the model will not call any tool and instead generates a message.
/// - `auto` means the model can pick between generating a message or calling one or more tools.
/// - `required` means the model must call one or more tools.
/// - Specifying a particular tool via `{\"type\": \"function\", \"function\": {\"name\": \"my_function\"}}` forces the model to call that tool.
///
/// `none` is the default when no tools are present. `auto` is the default if tools are present.
enum ToolTypeDeserializer {
/// `none` means the model will not call any tool and instead generates a message.
Null,
/// `auto` means the model can pick between generating a message or calling one or more tools.
#[schema(example = "auto")]
String(String),
/// Specifying a particular tool forces the model to call that tool, with structured function details.
#[schema(example = r#"{"type": "function", "function": {"name": "my_function"}}"#)]
TypedChoice(TypedChoice),
}
impl From<ToolTypeDeserializer> for ChatCompletionToolChoiceOption {
fn from(value: ToolTypeDeserializer) -> Self { fn from(value: ToolTypeDeserializer) -> Self {
match value { match value {
ToolTypeDeserializer::Null => ToolChoice(None), ToolTypeDeserializer::Null => ChatCompletionToolChoiceOption::NoTool,
ToolTypeDeserializer::String(s) => match s.as_str() { ToolTypeDeserializer::String(s) => match s.as_str() {
"none" => ToolChoice(Some(ToolType::NoTool)), "none" => ChatCompletionToolChoiceOption::NoTool,
"auto" => ToolChoice(Some(ToolType::OneOf)), "auto" => ChatCompletionToolChoiceOption::Auto,
_ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))), "required" => ChatCompletionToolChoiceOption::Required,
_ => ChatCompletionToolChoiceOption::Function(FunctionName { name: s }),
}, },
ToolTypeDeserializer::ToolType(TypedChoice::Function { function }) => { ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => {
ToolChoice(Some(ToolType::Function(function))) ChatCompletionToolChoiceOption::Function(function)
} }
} }
} }
@ -1661,20 +1685,27 @@ mod tests {
fn tool_choice_formats() { fn tool_choice_formats() {
#[derive(Deserialize)] #[derive(Deserialize)]
struct TestRequest { struct TestRequest {
tool_choice: ToolChoice, tool_choice: ChatCompletionToolChoiceOption,
} }
let none = r#"{"tool_choice":"none"}"#; let none = r#"{"tool_choice":"none"}"#;
let de_none: TestRequest = serde_json::from_str(none).unwrap(); let de_none: TestRequest = serde_json::from_str(none).unwrap();
assert_eq!(de_none.tool_choice, ToolChoice(Some(ToolType::NoTool))); assert_eq!(de_none.tool_choice, ChatCompletionToolChoiceOption::NoTool);
let auto = r#"{"tool_choice":"auto"}"#; let auto = r#"{"tool_choice":"auto"}"#;
let de_auto: TestRequest = serde_json::from_str(auto).unwrap(); let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
assert_eq!(de_auto.tool_choice, ToolChoice(Some(ToolType::OneOf))); assert_eq!(de_auto.tool_choice, ChatCompletionToolChoiceOption::Auto);
let ref_choice = ToolChoice(Some(ToolType::Function(FunctionName { let auto = r#"{"tool_choice":"required"}"#;
let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
assert_eq!(
de_auto.tool_choice,
ChatCompletionToolChoiceOption::Required
);
let ref_choice = ChatCompletionToolChoiceOption::Function(FunctionName {
name: "myfn".to_string(), name: "myfn".to_string(),
}))); });
let named = r#"{"tool_choice":"myfn"}"#; let named = r#"{"tool_choice":"myfn"}"#;
let de_named: TestRequest = serde_json::from_str(named).unwrap(); let de_named: TestRequest = serde_json::from_str(named).unwrap();

View File

@ -28,7 +28,7 @@ use crate::{
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
}; };
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use crate::{ChatCompletionToolChoiceOption, FunctionDefinition, HubPreprocessorConfig, ToolCall};
use crate::{ModelInfo, ModelsInfo}; use crate::{ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream; use async_stream::__private::AsyncStream;
use axum::extract::Extension; use axum::extract::Extension;
@ -1551,12 +1551,11 @@ GrammarType,
Usage, Usage,
StreamOptions, StreamOptions,
DeltaToolCall, DeltaToolCall,
ToolType,
Tool, Tool,
ToolCall, ToolCall,
Function, Function,
FunctionDefinition, FunctionDefinition,
ToolChoice, ChatCompletionToolChoiceOption,
ModelInfo, ModelInfo,
) )
), ),
@ -2522,7 +2521,7 @@ pub(crate) fn prepare_chat_input(
infer: &Infer, infer: &Infer,
response_format: Option<GrammarType>, response_format: Option<GrammarType>,
tools: Option<Vec<Tool>>, tools: Option<Vec<Tool>>,
tool_choice: ToolChoice, tool_choice: ChatCompletionToolChoiceOption,
tool_prompt: &str, tool_prompt: &str,
guideline: Option<String>, guideline: Option<String>,
messages: Vec<Message>, messages: Vec<Message>,
@ -2660,7 +2659,7 @@ mod tests {
&infer, &infer,
response_format, response_format,
tools, tools,
ToolChoice(None), ChatCompletionToolChoiceOption::Auto,
tool_prompt, tool_prompt,
guideline, guideline,
messages, messages,