diff --git a/docs/openapi.json b/docs/openapi.json
index 06c1f144..167bb3fb 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -955,7 +955,8 @@
             }
           }
         }
-      ]
+      ],
+      "description": ""
     },
     "ChatCompletionTopLogprob": {
       "type": "object",
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json
new file mode 100644
index 00000000..cf3f1fcc
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json
@@ -0,0 +1,27 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "role": "assistant",
+        "tool_calls": {
+          "function": {
+            "arguments": "<|eot_id|>",
+            "name": null
+          },
+          "id": "",
+          "index": 0,
+          "type": "function"
+        }
+      },
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null
+    }
+  ],
+  "created": 1729000499,
+  "id": "",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
+  "object": "chat.completion.chunk",
+  "system_fingerprint": "2.3.2-dev0-native",
+  "usage": null
+}
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json
new file mode 100644
index 00000000..fea26690
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json
@@ -0,0 +1,28 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "content": null,
+        "role": "assistant",
+        "tool_calls": {
+          "function": {
+            "arguments": "<|eot_id|>",
+            "name": null
+          },
+          "id": "",
+          "index": 0,
+          "type": "function"
+        }
+      },
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null
+    }
+  ],
+  "created": 1728998230,
+  "id": "",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
+  "object": "chat.completion.chunk",
+  "system_fingerprint": "2.3.2-dev0-native",
+  "usage": null
+}
diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py
index 98e75bb4..ce9eb4eb 100644
--- a/integration-tests/models/test_tools_llama.py
+++ b/integration-tests/models/test_tools_llama.py
@@ -1,4 +1,6 @@
 import pytest
+import requests
+import json
 
 
 @pytest.fixture(scope="module")
@@ -174,7 +176,7 @@ async def test_flash_llama_grammar_tools_choice(
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
+                "arguments": {"format": "celsius", "location": "Brooklyn, New York"},
             },
         }
     ]
@@ -327,3 +329,102 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream(
         == "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
     )
     assert last_response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
+    flash_llama_grammar_tools, response_snapshot
+):
+    responses = await flash_llama_grammar_tools.chat(
+        max_tokens=100,
+        seed=24,
+        tools=tools,
+        tool_choice="required",
+        messages=[
+            {
+                "role": "system",
+                "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
+            },
+            {
+                "role": "user",
+                "content": "Tell me a story about 3 sea creatures",
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    tool_calls_generated = ""
+    last_response = None
+    async for response in responses:
+        count += 1
+        assert response.choices[0].delta.content is None
+        tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
+        last_response = response
+
+    assert count == 29
+    assert (
+        tool_calls_generated
+        == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "San Francisco, CA"}}<|eot_id|>'
+    )
+    assert last_response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
+    flash_llama_grammar_tools, response_snapshot
+):
+    # using `requests` to send the request until the client library supports tool_choice as a function object
+    responses = requests.post(
+        f"{flash_llama_grammar_tools.base_url}/v1/chat/completions",
+        headers=flash_llama_grammar_tools.headers,
+        json={
+            "model": "tgi",
+            "messages": [
+                {
+                    "role": "system",
+                    "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
+                },
+                {
+                    "role": "user",
+                    "content": "Tell me a story about 3 sea creatures",
+                },
+            ],
+            "tools": tools,
+            "tool_choice": {
+                "type": "function",
+                "function": {"name": "get_current_weather"},
+            },
+            "seed": 24,
+            "max_tokens": 100,
+            "stream": True,
+        },
+        stream=True,
+    )
+    # iterate over the response in chunks
+    count = 0
+    tool_calls_generated = ""
+    last_response = None
+    for chunk in responses.iter_content(chunk_size=1024):
+        if chunk:
+            count += 1
+            # remove the "data: " prefix, trailing newline, and split the chunk into individual lines
+            lines = chunk.decode("utf-8").replace("data: ", "").rstrip("\n").split("\n")
+            for line in lines:
+                if line == "[DONE]":
+                    break
+                response = json.loads(line)
+                tool_calls_generated += response["choices"][0]["delta"]["tool_calls"][
+                    "function"
+                ]["arguments"]
+                last_response = response
+
+    assert count == 30
+    print(tool_calls_generated)
+    assert (
+        tool_calls_generated
+        == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Tokyo, JP"}}<|eot_id|>'
+    )
+    assert last_response == response_snapshot
diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs
index b9070812..77248559 100644
--- a/router/src/infer/tool_grammar.rs
+++ b/router/src/infer/tool_grammar.rs
@@ -27,39 +27,37 @@ impl ToolGrammar {
             return Ok((tools, None));
         }
 
-        let mut tools = tools.clone();
-
-        // add the no_tool function to the tools as long as we are not required to use a specific tool
-        if tool_choice != ChatCompletionToolChoiceOption::Required {
-            let no_tool = Tool {
-                r#type: "function".to_string(),
-                function: FunctionDefinition {
-                    name: "no_tool".to_string(),
-                    description: Some(
-                        "Open ended response with no specific tool selected".to_string(),
-                    ),
-                    arguments: json!({
-                        "type": "object",
-                        "properties": {
-                            "content": {
-                                "type": "string",
-                                "description": "The response content",
-                            }
-                        },
-                        "required": ["content"]
-                    }),
-                },
-            };
-            tools.push(no_tool);
-        }
-
-        // if tools are provided and no tool_choice we default to the OneOf
         let tools_to_use = match tool_choice {
             ChatCompletionToolChoiceOption::Function(function) => {
                 vec![Self::find_tool_by_name(&tools, &function.name)?]
             }
-            ChatCompletionToolChoiceOption::Required => tools.clone(),
-            ChatCompletionToolChoiceOption::Auto => tools.clone(),
+            ChatCompletionToolChoiceOption::Required => tools,
+            ChatCompletionToolChoiceOption::Auto => {
+                // only add the no_tool function if the user has selected the auto option
+                tools
+                    .iter()
+                    .cloned()
+                    .chain(std::iter::once(Tool {
+                        r#type: "function".to_string(),
+                        function: FunctionDefinition {
+                            name: "no_tool".to_string(),
+                            description: Some(
+                                "Open ended response with no specific tool selected".to_string(),
+                            ),
+                            arguments: json!({
+                                "type": "object",
+                                "properties": {
+                                    "content": {
+                                        "type": "string",
+                                        "description": "The response content",
+                                    }
+                                },
+                                "required": ["content"]
+                            }),
+                        },
+                    }))
+                    .collect::<Vec<_>>()
+            }
             ChatCompletionToolChoiceOption::NoTool => return Ok((tools, None)),
         };
 
@@ -121,6 +119,6 @@ impl ToolGrammar {
             },
         };
 
-        Ok((tools, Some(tool_schema)))
+        Ok((tools_to_use, Some(tool_schema)))
     }
 }
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 15d202a8..df70f827 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -946,21 +946,19 @@ impl ChatRequest {
             Some(temperature) if temperature == 0.0 => (false, None),
             other => (true, other),
         };
+        // unwrap or default (use "auto" if tools are present, and "none" if not)
+        let choice = tool_choice.unwrap_or_else(|| {
+            if tools.is_some() {
+                ChatCompletionToolChoiceOption::Auto
+            } else {
+                ChatCompletionToolChoiceOption::NoTool
+            }
+        });
         let (inputs, grammar, using_tools) = prepare_chat_input(
             infer,
             response_format,
-            tools.clone(),
-            // unwrap or default (use "auto" if tools are present, and "none" if not)
-            tool_choice.map_or_else(
-                || {
-                    if tools.is_some() {
-                        ChatCompletionToolChoiceOption::Auto
-                    } else {
-                        ChatCompletionToolChoiceOption::NoTool
-                    }
-                },
-                |t| t,
-            ),
+            tools,
+            choice,
             &tool_prompt,
             guideline,
             messages,
@@ -1023,6 +1021,7 @@ pub struct FunctionName {
 
 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)]
 #[serde(from = "ToolTypeDeserializer")]
+///
 pub enum ChatCompletionToolChoiceOption {
     /// Means the model can pick between generating a message or calling one or more tools.
     #[schema(rename = "auto")]
@@ -1034,7 +1033,7 @@ pub enum ChatCompletionToolChoiceOption {
     /// Means the model must call one or more tools.
     #[schema(rename = "required")]
     Required,
-    /// Forces the model to call a specific tool.
+    /// Forces the model to call a specific tool. This structure aligns with the `OpenAI` API schema to force a specific tool.
     #[schema(rename = "function")]
     #[serde(alias = "function")]
     Function(FunctionName),
@@ -1688,32 +1687,36 @@ mod tests {
             tool_choice: ChatCompletionToolChoiceOption,
         }
 
-        let none = r#"{"tool_choice":"none"}"#;
-        let de_none: TestRequest = serde_json::from_str(none).unwrap();
+        let de_none: TestRequest = serde_json::from_str(r#"{"tool_choice":"none"}"#).unwrap();
         assert_eq!(de_none.tool_choice, ChatCompletionToolChoiceOption::NoTool);
 
-        let auto = r#"{"tool_choice":"auto"}"#;
-        let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
+        let de_auto: TestRequest = serde_json::from_str(r#"{"tool_choice":"auto"}"#).unwrap();
         assert_eq!(de_auto.tool_choice, ChatCompletionToolChoiceOption::Auto);
 
-        let auto = r#"{"tool_choice":"required"}"#;
-        let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
+        let de_required: TestRequest =
+            serde_json::from_str(r#"{"tool_choice":"required"}"#).unwrap();
         assert_eq!(
-            de_auto.tool_choice,
+            de_required.tool_choice,
             ChatCompletionToolChoiceOption::Required
         );
 
-        let ref_choice = ChatCompletionToolChoiceOption::Function(FunctionName {
-            name: "myfn".to_string(),
-        });
+        let de_named: TestRequest = serde_json::from_str(r#"{"tool_choice":"myfn"}"#).unwrap();
+        assert_eq!(
+            de_named.tool_choice,
+            ChatCompletionToolChoiceOption::Function(FunctionName {
+                name: "myfn".to_string(),
+            })
+        );
 
-        let named = r#"{"tool_choice":"myfn"}"#;
-        let de_named: TestRequest = serde_json::from_str(named).unwrap();
-        assert_eq!(de_named.tool_choice, ref_choice);
-
-        let openai_named = r#"{"tool_choice":{"type":"function","function":{"name":"myfn"}}}"#;
-        let de_openai_named: TestRequest = serde_json::from_str(openai_named).unwrap();
-
-        assert_eq!(de_openai_named.tool_choice, ref_choice);
+        let de_openai_named: TestRequest = serde_json::from_str(
+            r#"{"tool_choice":{"type":"function","function":{"name":"myfn"}}}"#,
+        )
+        .unwrap();
+        assert_eq!(
+            de_openai_named.tool_choice,
+            ChatCompletionToolChoiceOption::Function(FunctionName {
+                name: "myfn".to_string(),
+            })
+        );
     }
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index 26a43f0a..c46351e5 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -2668,6 +2668,6 @@ mod tests {
         assert!(result.is_ok());
         let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
         assert_eq!(using_tools, true);
-        assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ened response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
+        assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ended response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
     }
 }
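
For reviewers who want to exercise the new behavior by hand, here is a minimal sketch (not part of the patch) that probes the three `tool_choice` modes against a running TGI instance. The base URL and port are assumptions about a default local deployment, the `"model": "tgi"` value is the usual placeholder, and the tool definition mirrors the `get_current_weather` fixture used in the tests above.

# Sketch only: probe tool_choice modes against a running TGI server.
# Assumes TGI is listening on localhost:3000 with a tool-capable model loaded;
# adjust the URL for your deployment.
import json

import requests

BASE_URL = "http://localhost:3000"

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["location", "format"],
            },
        },
    }
]

# "auto" may pick the injected no_tool escape hatch and answer in plain text;
# "required" and the function object must produce a real tool call, since
# no_tool is no longer added to the grammar for those modes.
for tool_choice in (
    "auto",
    "required",
    {"type": "function", "function": {"name": "get_current_weather"}},
):
    resp = requests.post(
        f"{BASE_URL}/v1/chat/completions",
        json={
            "model": "tgi",
            "messages": [
                {"role": "user", "content": "What is the weather like in Brooklyn?"}
            ],
            "tools": tools,
            "tool_choice": tool_choice,
            "max_tokens": 100,
        },
    )
    message = resp.json()["choices"][0]["message"]
    print(json.dumps({"tool_choice": tool_choice, "message": message}, indent=2))

The asymmetry this exercises is the point of the `tool_grammar.rs` change: the synthetic `no_tool` function is appended only under `Auto`, so `required` and named-function requests can no longer degenerate into an open-ended text response.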