From b5bf5b32ad48c08cef5eea1f515fab303d597fe7 Mon Sep 17 00:00:00 2001
From: David Holtz
Date: Wed, 16 Oct 2024 13:49:46 +0000
Subject: [PATCH] fix: simplify naming, tool choice default and improve test

---
 docs/openapi.json                             | 78 +++++++++----------
 ..._sea_creatures_stream_function_object.json |  2 +-
 ...r_tools_sea_creatures_stream_required.json |  2 +-
 integration-tests/models/test_tools_llama.py  |  7 +-
 router/src/infer/tool_grammar.rs              | 17 ++--
 router/src/lib.rs                             | 52 +++++--------
 router/src/server.rs                          |  4 +-
 7 files changed, 74 insertions(+), 88 deletions(-)

diff --git a/docs/openapi.json b/docs/openapi.json
index 167bb3fb..ba53f7ee 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -921,43 +921,6 @@
           }
         }
       },
-      "ChatCompletionToolChoiceOption": {
-        "oneOf": [
-          {
-            "type": "string",
-            "description": "Means the model can pick between generating a message or calling one or more tools.",
-            "enum": [
-              "auto"
-            ]
-          },
-          {
-            "type": "string",
-            "description": "Means the model will not call any tool and instead generates a message.",
-            "enum": [
-              "none"
-            ]
-          },
-          {
-            "type": "string",
-            "description": "Means the model must call one or more tools.",
-            "enum": [
-              "required"
-            ]
-          },
-          {
-            "type": "object",
-            "required": [
-              "function"
-            ],
-            "properties": {
-              "function": {
-                "$ref": "#/components/schemas/FunctionName"
-              }
-            }
-          }
-        ],
-        "description": ""
-      },
       "ChatCompletionTopLogprob": {
         "type": "object",
         "required": [
@@ -1092,9 +1055,9 @@
         "tool_choice": {
           "allOf": [
             {
-              "$ref": "#/components/schemas/ChatCompletionToolChoiceOption"
+              "$ref": "#/components/schemas/ToolChoice"
             }
           ],
-          "default": "null",
+          "default": "auto",
           "nullable": true
         },
@@ -2272,6 +2235,43 @@
           }
         }
       },
+      "ToolChoice": {
+        "oneOf": [
+          {
+            "type": "string",
+            "description": "Means the model can pick between generating a message or calling one or more tools.",
+            "enum": [
+              "auto"
+            ]
+          },
+          {
+            "type": "string",
+            "description": "Means the model will not call any tool and instead generates a message.",
+            "enum": [
+              "none"
+            ]
+          },
+          {
+            "type": "string",
+            "description": "Means the model must call one or more tools.",
+            "enum": [
+              "required"
+            ]
+          },
+          {
+            "type": "object",
+            "required": [
+              "function"
+            ],
+            "properties": {
+              "function": {
+                "$ref": "#/components/schemas/FunctionName"
+              }
+            }
+          }
+        ],
+        "description": ""
+      },
       "Url": {
         "type": "object",
         "required": [
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json
index cf3f1fcc..e64dd49d 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json
@@ -18,7 +18,7 @@
       "logprobs": null
     }
   ],
-  "created": 1729000499,
+  "created": 1729084854,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json
index fea26690..d8d538d6 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json
@@ -19,7 +19,7 @@
       "logprobs": null
     }
   ],
-  "created": 1728998230,
+  "created": 1729084850,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py
index ce9eb4eb..9fa993bd 100644
--- a/integration-tests/models/test_tools_llama.py
+++ b/integration-tests/models/test_tools_llama.py
@@ -395,7 +395,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
             "tools": tools,
             "tool_choice": {
                 "type": "function",
-                "function": {"name": "get_current_weather"},
+                "function": {"name": "get_n_day_weather_forecast"},
             },
             "seed": 24,
             "max_tokens": 100,
@@ -421,10 +421,9 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
                 ]["arguments"]
             last_response = response
 
-    assert count == 30
-    print(tool_calls_generated)
+    assert count == 39
     assert (
         tool_calls_generated
-        == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Tokyo, JP"}}<|eot_id|>'
+        == '{"function": {"_name": "get_n_day_weather_forecast", "format": "celsius", "location": "San Francisco, CA", "num_days":3}}<|eot_id|>'
     )
     assert last_response == response_snapshot
diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs
index 98253553..9c5ce2d8 100644
--- a/router/src/infer/tool_grammar.rs
+++ b/router/src/infer/tool_grammar.rs
@@ -1,7 +1,6 @@
 use crate::infer::InferError;
 use crate::{
-    ChatCompletionToolChoiceOption, FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool,
-    Properties, Tool,
+    FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
 };
 use serde_json::{json, Map, Value};
 use std::collections::HashMap;
@@ -20,19 +19,19 @@ impl ToolGrammar {
     pub fn apply(
         tools: Vec<Tool>,
-        tool_choice: ChatCompletionToolChoiceOption,
+        tool_choice: ToolChoice,
     ) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
-        // if no tools are provided, we return None
+        // if no tools are provided, we return None and an empty vec
         if tools.is_empty() {
-            return Ok((tools, None));
+            return Ok((Vec::with_capacity(0), None));
         }
 
         let tools_to_use = match tool_choice {
-            ChatCompletionToolChoiceOption::Function(function) => {
+            ToolChoice::Function(function) => {
                 vec![Self::find_tool_by_name(&tools, &function.name)?]
             }
-            ChatCompletionToolChoiceOption::Required => tools,
-            ChatCompletionToolChoiceOption::Auto => {
+            ToolChoice::Required => tools,
+            ToolChoice::Auto => {
                 // only add the no_tool function if the user has selected the auto option
                 tools
                     .iter()
                     .cloned()
                     .chain(std::iter::once(Tool {
                         r#type: "function".to_string(),
                         function: FunctionDefinition {
                             name: "no_tool".to_string(),
                             description: Some(
                                 "Open ended response with no specific tool selected".to_string(),
                             ),
                             arguments: json!({
                                 "type": "object",
                                 "properties": {
                                     "content": {
                                         "type": "string",
                                         "description": "The response content",
                                     }
                                 },
                                 "required": ["content"]
                             }),
                         },
                     }))
                     .collect::<Vec<_>>()
             }
-            ChatCompletionToolChoiceOption::NoTool => Vec::with_capacity(0),
+            ToolChoice::NoTool => Vec::with_capacity(0),
         };
 
         let functions: HashMap<String, serde_json::Value> = tools_to_use
diff --git a/router/src/lib.rs b/router/src/lib.rs
index f76e440b..59b300dd 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -893,13 +893,13 @@ pub(crate) struct ChatRequest {
     /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
     #[serde(default)]
-    #[schema(nullable = true, default = "null", example = "null")]
-    pub tool_choice: Option<ChatCompletionToolChoiceOption>,
+    #[schema(nullable = true, default = "auto", example = "auto")]
+    pub tool_choice: Option<ToolChoice>,
 
     /// Response format constraints for the generation.
     ///
     /// NOTE: A request can use `response_format` OR `tools` but not both.
     #[serde(default)]
     #[schema(nullable = true, default = "null", example = "null")]
     pub response_format: Option<GrammarType>,
 
     /// A guideline to be used in the chat_template
@@ -946,14 +946,8 @@ impl ChatRequest {
             Some(temperature) if temperature == 0.0 => (false, None),
             other => (true, other),
         };
-        // unwrap or default (use "auto" if tools are present, and "none" if not)
-        let tool_choice = tool_choice.unwrap_or_else(|| {
-            if tools.is_some() {
-                ChatCompletionToolChoiceOption::Auto
-            } else {
-                ChatCompletionToolChoiceOption::NoTool
-            }
-        });
+        // if no tool_choice is set, set default (Auto)
+        let tool_choice = tool_choice.unwrap_or_default();
 
         if response_format.is_some() && tools.is_some() {
             return Err(InferError::ToolError(
@@ -1045,21 +1039,18 @@ pub struct FunctionName {
 
 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)]
 #[serde(from = "ToolTypeDeserializer")]
+#[serde(rename_all = "snake_case")]
 ///
-pub enum ChatCompletionToolChoiceOption {
+pub enum ToolChoice {
     /// Means the model can pick between generating a message or calling one or more tools.
-    #[schema(rename = "auto")]
+    #[default]
     Auto,
     /// Means the model will not call any tool and instead generates a message.
     #[schema(rename = "none")]
-    #[default]
     NoTool,
     /// Means the model must call one or more tools.
-    #[schema(rename = "required")]
     Required,
     /// Forces the model to call a specific tool. This structure aligns with the `OpenAI` API schema to force a specific tool.
-    #[schema(rename = "function")]
-    #[serde(alias = "function")]
     Function(FunctionName),
 }
 
@@ -1085,18 +1076,18 @@ enum ToolTypeDeserializer {
     TypedChoice(TypedChoice),
 }
 
-impl From<ToolTypeDeserializer> for ChatCompletionToolChoiceOption {
+impl From<ToolTypeDeserializer> for ToolChoice {
     fn from(value: ToolTypeDeserializer) -> Self {
         match value {
-            ToolTypeDeserializer::Null => ChatCompletionToolChoiceOption::NoTool,
+            ToolTypeDeserializer::Null => ToolChoice::NoTool,
             ToolTypeDeserializer::String(s) => match s.as_str() {
-                "none" => ChatCompletionToolChoiceOption::NoTool,
-                "auto" => ChatCompletionToolChoiceOption::Auto,
-                "required" => ChatCompletionToolChoiceOption::Required,
-                _ => ChatCompletionToolChoiceOption::Function(FunctionName { name: s }),
+                "none" => ToolChoice::NoTool,
+                "auto" => ToolChoice::Auto,
+                "required" => ToolChoice::Required,
+                _ => ToolChoice::Function(FunctionName { name: s }),
             },
             ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => {
-                ChatCompletionToolChoiceOption::Function(function)
+                ToolChoice::Function(function)
             }
         }
     }
@@ -1709,26 +1700,23 @@ mod tests {
     fn tool_choice_formats() {
         #[derive(Deserialize)]
         struct TestRequest {
-            tool_choice: ChatCompletionToolChoiceOption,
+            tool_choice: ToolChoice,
         }
 
         let de_none: TestRequest = serde_json::from_str(r#"{"tool_choice":"none"}"#).unwrap();
-        assert_eq!(de_none.tool_choice, ChatCompletionToolChoiceOption::NoTool);
+        assert_eq!(de_none.tool_choice, ToolChoice::NoTool);
 
         let de_auto: TestRequest = serde_json::from_str(r#"{"tool_choice":"auto"}"#).unwrap();
-        assert_eq!(de_auto.tool_choice, ChatCompletionToolChoiceOption::Auto);
+        assert_eq!(de_auto.tool_choice, ToolChoice::Auto);
 
         let de_required: TestRequest =
             serde_json::from_str(r#"{"tool_choice":"required"}"#).unwrap();
-        assert_eq!(
-            de_required.tool_choice,
-            ChatCompletionToolChoiceOption::Required
-        );
+        assert_eq!(de_required.tool_choice, ToolChoice::Required);
 
         let de_named: TestRequest = serde_json::from_str(r#"{"tool_choice":"myfn"}"#).unwrap();
         assert_eq!(
             de_named.tool_choice,
-            ChatCompletionToolChoiceOption::Function(FunctionName {
+            ToolChoice::Function(FunctionName {
                 name: "myfn".to_string(),
             })
         );
 
         let de_openai_named: TestRequest = serde_json::from_str(
             r#"{"tool_choice":{"type":"function","function":{"name":"myfn"}}}"#,
         )
         .unwrap();
         assert_eq!(
             de_openai_named.tool_choice,
-            ChatCompletionToolChoiceOption::Function(FunctionName {
+            ToolChoice::Function(FunctionName {
                 name: "myfn".to_string(),
             })
         );
diff --git a/router/src/server.rs b/router/src/server.rs
index 0f970391..911c77d8 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -27,7 +27,7 @@ use crate::{
     ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
     CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
 };
-use crate::{ChatCompletionToolChoiceOption, FunctionDefinition, HubPreprocessorConfig, ToolCall};
+use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
 use crate::{ModelInfo, ModelsInfo};
 use async_stream::__private::AsyncStream;
 use axum::extract::Extension;
@@ -1554,7 +1554,7 @@ Tool,
             ToolCall,
             Function,
             FunctionDefinition,
-            ChatCompletionToolChoiceOption,
+            ToolChoice,
             ModelInfo,
         )
     ),