fix: simplify naming, tool choice default and improve test

This commit is contained in:
David Holtz 2024-10-16 13:49:46 +00:00 committed by drbh
parent dd759e7914
commit b5bf5b32ad
7 changed files with 74 additions and 88 deletions

View File

@ -921,43 +921,6 @@
}
}
},
"ChatCompletionToolChoiceOption": {
"oneOf": [
{
"type": "string",
"description": "Means the model can pick between generating a message or calling one or more tools.",
"enum": [
"auto"
]
},
{
"type": "string",
"description": "Means the model will not call any tool and instead generates a message.",
"enum": [
"none"
]
},
{
"type": "string",
"description": "Means the model must call one or more tools.",
"enum": [
"required"
]
},
{
"type": "object",
"required": [
"function"
],
"properties": {
"function": {
"$ref": "#/components/schemas/FunctionName"
}
}
}
],
"description": "<https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>"
},
"ChatCompletionTopLogprob": {
"type": "object",
"required": [
@ -1052,7 +1015,7 @@
"$ref": "#/components/schemas/GrammarType"
}
],
"default": "null",
"default": "auto",
"nullable": true
},
"seed": {
@ -1092,7 +1055,7 @@
"tool_choice": {
"allOf": [
{
"$ref": "#/components/schemas/ChatCompletionToolChoiceOption"
"$ref": "#/components/schemas/ToolChoice"
}
],
"default": "null",
@ -2272,6 +2235,43 @@
}
}
},
"ToolChoice": {
"oneOf": [
{
"type": "string",
"description": "Means the model can pick between generating a message or calling one or more tools.",
"enum": [
"auto"
]
},
{
"type": "string",
"description": "Means the model will not call any tool and instead generates a message.",
"enum": [
"none"
]
},
{
"type": "string",
"description": "Means the model must call one or more tools.",
"enum": [
"required"
]
},
{
"type": "object",
"required": [
"function"
],
"properties": {
"function": {
"$ref": "#/components/schemas/FunctionName"
}
}
}
],
"description": "<https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>"
},
"Url": {
"type": "object",
"required": [

View File

@ -18,7 +18,7 @@
"logprobs": null
}
],
"created": 1729000499,
"created": 1729084854,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",

View File

@ -19,7 +19,7 @@
"logprobs": null
}
],
"created": 1728998230,
"created": 1729084850,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",

View File

@ -395,7 +395,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
"tools": tools,
"tool_choice": {
"type": "function",
"function": {"name": "get_current_weather"},
"function": {"name": "get_n_day_weather_forecast"},
},
"seed": 24,
"max_tokens": 100,
@ -421,10 +421,9 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
]["arguments"]
last_response = response
assert count == 30
print(tool_calls_generated)
assert count == 39
assert (
tool_calls_generated
== '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Tokyo, JP"}}<|eot_id|>'
== '{"function": {"_name": "get_n_day_weather_forecast", "format": "celsius", "location": "San Francisco, CA", "num_days":3}}<|eot_id|>'
)
assert last_response == response_snapshot

View File

@ -1,7 +1,6 @@
use crate::infer::InferError;
use crate::{
ChatCompletionToolChoiceOption, FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool,
Properties, Tool,
FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
};
use serde_json::{json, Map, Value};
use std::collections::HashMap;
@ -20,19 +19,19 @@ impl ToolGrammar {
pub fn apply(
tools: Vec<Tool>,
tool_choice: ChatCompletionToolChoiceOption,
tool_choice: ToolChoice,
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
// if no tools are provided, we return None
// if no tools are provided, we return None and an empty vec
if tools.is_empty() {
return Ok((tools, None));
return Ok((Vec::with_capacity(0), None));
}
let tools_to_use = match tool_choice {
ChatCompletionToolChoiceOption::Function(function) => {
ToolChoice::Function(function) => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ChatCompletionToolChoiceOption::Required => tools,
ChatCompletionToolChoiceOption::Auto => {
ToolChoice::Required => tools,
ToolChoice::Auto => {
// only add the no_tool function if the user has selected the auto option
tools
.iter()
@ -58,7 +57,7 @@ impl ToolGrammar {
}))
.collect::<Vec<_>>()
}
ChatCompletionToolChoiceOption::NoTool => Vec::with_capacity(0),
ToolChoice::NoTool => Vec::with_capacity(0),
};
let functions: HashMap<String, serde_json::Value> = tools_to_use

View File

@ -893,13 +893,13 @@ pub(crate) struct ChatRequest {
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub tool_choice: Option<ChatCompletionToolChoiceOption>,
pub tool_choice: Option<ToolChoice>,
/// Response format constraints for the generation.
///
/// NOTE: A request can use `response_format` OR `tools` but not both.
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
#[schema(nullable = true, default = "auto", example = "auto")]
pub response_format: Option<GrammarType>,
/// A guideline to be used in the chat_template
@ -946,14 +946,8 @@ impl ChatRequest {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
// unwrap or default (use "auto" if tools are present, and "none" if not)
let tool_choice = tool_choice.unwrap_or_else(|| {
if tools.is_some() {
ChatCompletionToolChoiceOption::Auto
} else {
ChatCompletionToolChoiceOption::NoTool
}
});
// if no tool_choice is set, set default (Auto)
let tool_choice = tool_choice.unwrap_or_default();
if response_format.is_some() && tools.is_some() {
return Err(InferError::ToolError(
@ -1045,21 +1039,18 @@ pub struct FunctionName {
/// Tool-selection policy for a chat request: let the model decide, forbid tool calls,
/// require a tool call, or force one specific named function.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)]
// Deserialization is routed through `ToolTypeDeserializer` and its `From` conversion
// instead of serde's derived enum representation.
#[serde(from = "ToolTypeDeserializer")]
#[serde(rename_all = "snake_case")]
/// <https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>
pub enum ChatCompletionToolChoiceOption {
pub enum ToolChoice {
/// Means the model can pick between generating a message or calling one or more tools.
#[schema(rename = "auto")]
#[default]
Auto,
/// Means the model will not call any tool and instead generates a message.
#[schema(rename = "none")]
#[default]
NoTool,
/// Means the model must call one or more tools.
#[schema(rename = "required")]
Required,
/// Forces the model to call a specific tool. This structure aligns with the `OpenAI` API schema to force a specific tool.
#[schema(rename = "function")]
#[serde(alias = "function")]
Function(FunctionName),
}
@ -1085,18 +1076,18 @@ enum ToolTypeDeserializer {
TypedChoice(TypedChoice),
}
/// Converts the loosely-typed deserialization helper into the tool-choice enum.
///
/// The conversion is infallible: every input shape maps to some variant.
impl From<ToolTypeDeserializer> for ChatCompletionToolChoiceOption {
impl From<ToolTypeDeserializer> for ToolChoice {
fn from(value: ToolTypeDeserializer) -> Self {
match value {
// An explicit null is treated the same as the string "none".
ToolTypeDeserializer::Null => ChatCompletionToolChoiceOption::NoTool,
ToolTypeDeserializer::Null => ToolChoice::NoTool,
// "none", "auto" and "required" are reserved keywords; any other string is
// taken verbatim as the name of the function to force (no validation here).
ToolTypeDeserializer::String(s) => match s.as_str() {
"none" => ChatCompletionToolChoiceOption::NoTool,
"auto" => ChatCompletionToolChoiceOption::Auto,
"required" => ChatCompletionToolChoiceOption::Required,
_ => ChatCompletionToolChoiceOption::Function(FunctionName { name: s }),
"none" => ToolChoice::NoTool,
"auto" => ToolChoice::Auto,
"required" => ToolChoice::Required,
_ => ToolChoice::Function(FunctionName { name: s }),
},
// Typed object form: the wrapped `FunctionName` is passed through unchanged.
ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => {
ChatCompletionToolChoiceOption::Function(function)
ToolChoice::Function(function)
}
}
}
@ -1709,26 +1700,23 @@ mod tests {
fn tool_choice_formats() {
#[derive(Deserialize)]
struct TestRequest {
tool_choice: ChatCompletionToolChoiceOption,
tool_choice: ToolChoice,
}
let de_none: TestRequest = serde_json::from_str(r#"{"tool_choice":"none"}"#).unwrap();
assert_eq!(de_none.tool_choice, ChatCompletionToolChoiceOption::NoTool);
assert_eq!(de_none.tool_choice, ToolChoice::NoTool);
let de_auto: TestRequest = serde_json::from_str(r#"{"tool_choice":"auto"}"#).unwrap();
assert_eq!(de_auto.tool_choice, ChatCompletionToolChoiceOption::Auto);
assert_eq!(de_auto.tool_choice, ToolChoice::Auto);
let de_required: TestRequest =
serde_json::from_str(r#"{"tool_choice":"required"}"#).unwrap();
assert_eq!(
de_required.tool_choice,
ChatCompletionToolChoiceOption::Required
);
assert_eq!(de_required.tool_choice, ToolChoice::Required);
let de_named: TestRequest = serde_json::from_str(r#"{"tool_choice":"myfn"}"#).unwrap();
assert_eq!(
de_named.tool_choice,
ChatCompletionToolChoiceOption::Function(FunctionName {
ToolChoice::Function(FunctionName {
name: "myfn".to_string(),
})
);
@ -1739,7 +1727,7 @@ mod tests {
.unwrap();
assert_eq!(
de_openai_named.tool_choice,
ChatCompletionToolChoiceOption::Function(FunctionName {
ToolChoice::Function(FunctionName {
name: "myfn".to_string(),
})
);

View File

@ -27,7 +27,7 @@ use crate::{
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
};
use crate::{ChatCompletionToolChoiceOption, FunctionDefinition, HubPreprocessorConfig, ToolCall};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
use crate::{ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream;
use axum::extract::Extension;
@ -1554,7 +1554,7 @@ Tool,
ToolCall,
Function,
FunctionDefinition,
ChatCompletionToolChoiceOption,
ToolChoice,
ModelInfo,
)
),