From 0e30e6582224f981bed3416b0bc6ce1c827fdb9a Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 22 Feb 2024 18:26:49 +0000 Subject: [PATCH] feat: respect tool choice --- clients/python/text_generation/client.py | 8 ++++ clients/python/text_generation/types.py | 2 + .../test_flash_llama_grammar_no_tools.json | 4 +- .../test_flash_llama_grammar_tools.json | 8 ++-- ...test_flash_llama_grammar_tools_choice.json | 24 ++++++++++ integration-tests/models/test_tools_llama.py | 46 +++++++++++++++---- router/src/lib.rs | 8 +++- router/src/server.rs | 27 +++++++++-- 8 files changed, 107 insertions(+), 20 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 8278094f..2bc5abc1 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -78,6 +78,7 @@ class Client: temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List[Tool]] = None, + tool_choice: Optional[str] = None, ): """ Given a list of messages, generate a response asynchronously @@ -112,6 +113,8 @@ class Client: higher are kept for generation tools (`List[Tool]`): List of tools to use + tool_choice (`str`): + The tool to use """ request = ChatRequest( @@ -129,6 +132,7 @@ class Client: temperature=temperature, top_p=top_p, tools=tools, + tool_choice=tool_choice, ) resp = requests.post( @@ -412,6 +416,7 @@ class AsyncClient: temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List[Tool]] = None, + tool_choice: Optional[str] = None, ): """ Given a list of messages, generate a response asynchronously @@ -446,6 +451,8 @@ class AsyncClient: higher are kept for generation tools (`List[Tool]`): List of tools to use + tool_choice (`str`): + The tool to use """ request = ChatRequest( @@ -463,6 +470,7 @@ class AsyncClient: temperature=temperature, top_p=top_p, tools=tools, + tool_choice=tool_choice, ) print(self.base_url) async with ClientSession( diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index 1c6a1c47..8be158dd 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -86,6 +86,8 @@ class ChatRequest(BaseModel): top_p: Optional[float] = None # List of tools to be used tools: Optional[List[Tool]] = None + # Choice of tool to be used + tool_choice: Optional[str] = None class Parameters(BaseModel): diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json index 0d3211b7..206128a9 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json @@ -5,13 +5,13 @@ "index": 0, "logprobs": null, "message": { - "content": "As an up-to-date news station, our team has access to the latest information on weather conditions in Brooklyn, New York. Here is what we have learned so far:\n\n- Located in New York City, Brooklyn has a history of harsh weather patterns, especially in winter. The city's cold penchant makes it a popular winter destination, and meteorologists predict \"bomb cyclone\" conditions in the year 2021. - Due to", + "content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally", "name": null, "role": "assistant" } } ], - "created": 1708623190, + "created": 1708626137, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json index dc4561e6..003b1772 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json @@ -5,20 +5,20 @@ "index": 0, "logprobs": null, "message": { - "content": "{\"function\":{\"format\": \"celsius\", \"location\": \"Brooklyn, NYC\", \"num_days\": 1255}}", + "content": "{\"function\":{\"format\": \"celsius\", \"location\": \"New York, NY\", \"num_days\": 14}}", "name": null, "role": "assistant" } } ], - "created": 1708623212, + "created": 1708626137, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "1.4.2-native", "usage": { - "completion_tokens": 33, + "completion_tokens": 29, "prompt_tokens": 318, - "total_tokens": 351 + "total_tokens": 347 } } diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json new file mode 100644 index 00000000..644f20d7 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json @@ -0,0 +1,24 @@ +{ + "choices": [ + { + "finish_reason": "eos_token", + "index": 0, + "logprobs": null, + "message": { + "content": "{\"function\":{\"format\": \"celsius\", \"location\": \"New York, NY\"}}", + "name": null, + "role": "assistant" + } + } + ], + "created": 1708626030, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "1.4.2-native", + "usage": { + "completion_tokens": 21, + "prompt_tokens": 189, + "total_tokens": 210 + } +} diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index ae051f82..8c8ce126 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -73,12 +73,12 @@ tools = [ @pytest.mark.asyncio @pytest.mark.private -async def test_flash_llama_grammar_no_tools_regex( +async def test_flash_llama_grammar_no_tools( flash_llama_grammar_tools, response_snapshot ): response = await flash_llama_grammar_tools.chat( max_tokens=100, - seed=0, + seed=1, messages=[ { "role": "system", @@ -93,19 +93,17 @@ async def test_flash_llama_grammar_no_tools_regex( assert ( response.choices[0].message.content - == 'As an up-to-date news station, our team has access to the latest information on weather conditions in Brooklyn, New York. Here is what we have learned so far:\n\n- Located in New York City, Brooklyn has a history of harsh weather patterns, especially in winter. The city\'s cold penchant makes it a popular winter destination, and meteorologists predict "bomb cyclone" conditions in the year 2021. - Due to' + == "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally" ) assert response == response_snapshot @pytest.mark.asyncio @pytest.mark.private -async def test_flash_llama_grammar_tools_regex( - flash_llama_grammar_tools, response_snapshot -): +async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot): response = await flash_llama_grammar_tools.chat( max_tokens=100, - seed=0, + seed=1, tools=tools, presence_penalty=-1.1, messages=[ @@ -119,9 +117,39 @@ async def test_flash_llama_grammar_tools_regex( }, ], ) - assert len(response.choices[0].message.content) == 81 + assert len(response.choices[0].message.content) == 78 assert ( response.choices[0].message.content - == """{"function":{"format": "celsius", "location": "Brooklyn, NYC", "num_days": 1255}}""" + == """{"function":{"format": "celsius", "location": "New York, NY", "num_days": 14}}""" + ) + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_choice( + flash_llama_grammar_tools, response_snapshot +): + response = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=1, + tools=tools, + tool_choice="get_current_weather", + presence_penalty=-1.1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + assert len(response.choices[0].message.content) == 62 + assert ( + response.choices[0].message.content + == """{"function":{"format": "celsius", "location": "New York, NY"}}""" ) assert response == response_snapshot diff --git a/router/src/lib.rs b/router/src/lib.rs index 3365bf98..66213975 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -526,6 +526,11 @@ pub(crate) struct ChatRequest { #[serde(default)] #[schema(nullable = true, example = "null")] pub tools: Option>, + + /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. + #[serde(default)] + #[schema(nullable = true, example = "null")] + pub tool_choice: Option, } #[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)] @@ -536,10 +541,9 @@ pub struct Tools { pub any_of: Vec, } -// add traut to convert to serde_json::Value for tools +// Allows Tools to be converted to a valid JSON schema object impl From for serde_json::Value { fn from(tools: Tools) -> Self { - println!("tools: {:?}", tools); let mut map = serde_json::Map::new(); let mut functions = serde_json::Map::new(); for (name, value) in tools.function { diff --git a/router/src/server.rs b/router/src/server.rs index e4764ff4..2d4944f0 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -601,10 +601,31 @@ async fn chat_completions( // if theres a tools object, we need to decompose it and use the function name as the key // and the parameters as the value in the "$functions" object. - let grammar = if let Some(req_tools) = &req.tools { + let grammar = if let Some(ref req_tools) = &req.tools { + // get the tool_choice if there is one + let tool_choice = &req.tool_choice; + let tools_to_use = if let Some(tool_choice) = tool_choice { + // get the tool based on the tool_choice + let tool = req_tools + .iter() + .find(|tool| tool.function.name == *tool_choice) + .ok_or_else(|| { + ( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: "Input validation error".to_string(), + error_type: "Input validation error".to_string(), + }), + ) + })?; + vec![tool.clone()] + } else { + req_tools.clone() + }; + let functions: HashMap = { let mut tools = HashMap::new(); - for tool in req_tools { + for tool in &tools_to_use { let func = tool.function.clone(); let name = func.name; let parameters = match func.parameters.as_object() { @@ -627,7 +648,7 @@ async fn chat_completions( let tools = Tools { function: functions, - any_of: req_tools + any_of: tools_to_use .iter() .map(|tool| FunctionRef::new(&tool.function.name)) .collect(),