From 407531708e492abc74ae35bd871437174631ae01 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Fri, 18 Oct 2024 15:36:40 +0000 Subject: [PATCH] fix: adjust tool choice none logic, add test and small refactors --- docs/openapi.json | 4 +- integration-tests/models/test_tools_llama.py | 41 ++++++++++++++++++++ router/src/infer/tool_grammar.rs | 40 +++++++++---------- router/src/lib.rs | 38 +++++++++--------- 4 files changed, 81 insertions(+), 42 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index ba53f7ee..08aab6c5 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1015,7 +1015,7 @@ "$ref": "#/components/schemas/GrammarType" } ], - "default": "auto", + "default": "null", "nullable": true }, "seed": { @@ -1058,7 +1058,7 @@ "$ref": "#/components/schemas/ToolChoice" } ], - "default": "null", + "default": "auto", "nullable": true }, "tool_prompt": { diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index 9fa993bd..b5821945 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -371,6 +371,47 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required( assert last_response == response_snapshot +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_sea_creatures_stream_none( + flash_llama_grammar_tools, response_snapshot +): + responses = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=24, + tools=tools, + tool_choice="none", + messages=[ + { + "role": "system", + "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.", + }, + { + "role": "user", + "content": "Tell me a story about 3 sea creatures", + }, + ], + stream=True, + ) + + count = 0 + content_generated = "" + last_response = None + async for response in responses: + count += 1 + content_generated += response.choices[0].delta.content + last_response = response + assert response.choices[0].delta.tool_calls is None + + assert count == 100 + print(content_generated) + assert ( + content_generated + == "Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish, lived three dear friends: Luna the sea turtle, Finley the friendly fish, and Crusty the wise crab.\n\nLuna was the oldest of the three. She had traveled the world, exploring hidden caves and shipwrecks, and collecting sparkling shells and shiny pebbles. Her shell was a beautiful mosaic of blues and greens, and her gentle eyes twinkled with the secrets of the deep" + ) + assert last_response == response_snapshot + + @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object( diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index 9c5ce2d8..7770cd9d 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -20,12 +20,7 @@ impl ToolGrammar { pub fn apply( tools: Vec, tool_choice: ToolChoice, - ) -> Result<(Vec, Option), InferError> { - // if no tools are provided, we return None and an empty vec - if tools.is_empty() { - return Ok((Vec::with_capacity(0), None)); - } - + ) -> Result, JsonSchemaTool)>, InferError> { let tools_to_use = match tool_choice { ToolChoice::Function(function) => { vec![Self::find_tool_by_name(&tools, &function.name)?] @@ -57,9 +52,14 @@ impl ToolGrammar { })) .collect::>() } - ToolChoice::NoTool => Vec::with_capacity(0), + ToolChoice::NoTool => vec![], }; + // if no tools are provided or if the user has selected the no_tool option, return None + if tools_to_use.is_empty() { + return Ok(None); + } + let functions: HashMap = tools_to_use .iter() .map(|tool| { @@ -106,22 +106,18 @@ impl ToolGrammar { }) .collect(); - let tool_schema = if tools_to_use.is_empty() { - None - } else { - Some(JsonSchemaTool { - functions_map: FunctionsMap { functions }, - properties: Properties { - function: tools_to_use - .iter() - .map(|tool| FunctionRef { - ref_path: format!("#/$functions/{}", tool.function.name.clone()), - }) - .collect(), - }, - }) + let tool_schema = JsonSchemaTool { + functions_map: FunctionsMap { functions }, + properties: Properties { + function: tools_to_use + .iter() + .map(|tool| FunctionRef { + ref_path: format!("#/$functions/{}", tool.function.name.clone()), + }) + .collect(), + }, }; - Ok((tools_to_use, tool_schema)) + Ok(Some((tools_to_use, tool_schema))) } } diff --git a/router/src/lib.rs b/router/src/lib.rs index 59b300dd..009b57cf 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -892,14 +892,14 @@ pub(crate) struct ChatRequest { /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. #[serde(default)] - #[schema(nullable = true, default = "null", example = "null")] - pub tool_choice: Option, + #[schema(nullable = true, default = "auto", example = "auto")] + pub tool_choice: ToolChoice, /// Response format constraints for the generation. /// /// NOTE: A request can use `response_format` OR `tools` but not both. #[serde(default)] - #[schema(nullable = true, default = "auto", example = "auto")] + #[schema(nullable = true, default = "null", example = "null")] pub response_format: Option, /// A guideline to be used in the chat_template @@ -946,8 +946,6 @@ impl ChatRequest { Some(temperature) if temperature == 0.0 => (false, None), other => (true, other), }; - // if no tool_choice is set, set default (Auto) - let tool_choice = tool_choice.unwrap_or_default(); if response_format.is_some() && tools.is_some() { return Err(InferError::ToolError( @@ -962,18 +960,22 @@ impl ChatRequest { } None => { if let Some(tools) = tools { - let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?; - - let grammar = tool_schema - .as_ref() - .map(|t| GrammarType::Json(serde_json::json!(t))); - - let inputs: String = infer.apply_chat_template( - guideline, - messages, - Some((updated_tools, tool_prompt)), - )?; - (inputs, grammar, tool_schema.is_some()) + match ToolGrammar::apply(tools, tool_choice)? { + Some((updated_tools, tool_schema)) => { + let grammar = GrammarType::Json(serde_json::json!(tool_schema)); + let inputs: String = infer.apply_chat_template( + guideline, + messages, + Some((updated_tools, tool_prompt)), + )?; + (inputs, Some(grammar), true) + } + None => { + // same as if no response_format or tools are set + let inputs = infer.apply_chat_template(guideline, messages, None)?; + (inputs, None, false) + } + } } else { // if no response_format or tools are set simply apply the chat template to generate inputs let inputs = infer.apply_chat_template(guideline, messages, None)?; @@ -1046,7 +1048,7 @@ pub enum ToolChoice { #[default] Auto, /// Means the model will not call any tool and instead generates a message. - #[schema(rename = "none")] + #[serde(rename = "none")] NoTool, /// Means the model must call one or more tools. Required,