From 7a37655d8e9209a4724950fc0d439ea16e38a397 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 26 Feb 2024 14:18:09 +0000 Subject: [PATCH] feat: improve client for tools and fix default choice --- clients/python/text_generation/client.py | 1 - clients/python/text_generation/types.py | 15 +++- .../test_flash_llama_grammar_no_tools.json | 8 +- .../test_flash_llama_grammar_tools.json | 28 +++++-- .../test_flash_llama_grammar_tools_auto.json | 38 ++++++++++ ...test_flash_llama_grammar_tools_choice.json | 25 +++++-- integration-tests/models/test_tools_llama.py | 73 ++++++++++++++++--- router/src/lib.rs | 1 + 8 files changed, 160 insertions(+), 29 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 2bc5abc1..51ebbc24 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -472,7 +472,6 @@ class AsyncClient: tools=tools, tool_choice=tool_choice, ) - print(self.base_url) async with ClientSession( headers=self.headers, cookies=self.cookies, timeout=self.timeout ) as session: diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index 8be158dd..8ca46654 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -19,13 +19,24 @@ class Grammar(BaseModel): value: Union[str, dict] +class ToolCall(BaseModel): + # Id of the tool call + id: int + # Type of the tool call + type: str + # Function details of the tool call + function: dict + + class Message(BaseModel): # Role of the message sender role: str # Content of the message - content: str + content: Optional[str] # Optional name of the message sender name: Optional[str] = None + # Tool calls associated with the chat completion + tool_calls: Optional[Any] = None class Tool(BaseModel): @@ -44,6 +55,8 @@ class ChatCompletionComplete(BaseModel): logprobs: Optional[Any] # Reason for completion finish_reason: str + # Usage details of the chat completion + usage: Any class ChatComplete(BaseModel): diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json index 206128a9..3c4b4aea 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json @@ -7,11 +7,13 @@ "message": { "content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally", "name": null, - "role": "assistant" - } + "role": "assistant", + "tool_calls": null + }, + "usage": null } ], - "created": 1708626137, + "created": 1708957015, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json index 003b1772..a89501ca 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json @@ -5,20 +5,34 @@ "index": 0, "logprobs": null, "message": { - "content": "{\"function\":{\"format\": \"celsius\", \"location\": \"New York, NY\", \"num_days\": 14}}", + "content": null, "name": null, - "role": "assistant" - } + "role": "assistant", + "tool_calls": { + "function": { + "description": null, + "name": "tools", + "parameters": { + "format": "celsius", + "location": "San Francisco", + "num_days": 2 + } + }, + "id": 0, + "type": "function" + } + }, + "usage": null } ], - "created": 1708626137, + "created": 1708957016, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "1.4.2-native", "usage": { - "completion_tokens": 29, - "prompt_tokens": 318, - "total_tokens": 347 + "completion_tokens": 36, + "prompt_tokens": 313, + "total_tokens": 349 } } diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json new file mode 100644 index 00000000..a89501ca --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json @@ -0,0 +1,38 @@ +{ + "choices": [ + { + "finish_reason": "eos_token", + "index": 0, + "logprobs": null, + "message": { + "content": null, + "name": null, + "role": "assistant", + "tool_calls": { + "function": { + "description": null, + "name": "tools", + "parameters": { + "format": "celsius", + "location": "San Francisco", + "num_days": 2 + } + }, + "id": 0, + "type": "function" + } + }, + "usage": null + } + ], + "created": 1708957016, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "1.4.2-native", + "usage": { + "completion_tokens": 36, + "prompt_tokens": 313, + "total_tokens": 349 + } +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json index 644f20d7..83642258 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json @@ -5,20 +5,33 @@ "index": 0, "logprobs": null, "message": { - "content": "{\"function\":{\"format\": \"celsius\", \"location\": \"New York, NY\"}}", + "content": null, "name": null, - "role": "assistant" - } + "role": "assistant", + "tool_calls": { + "function": { + "description": null, + "name": "tools", + "parameters": { + "format": "celsius", + "location": "New York, NY" + } + }, + "id": 0, + "type": "function" + } + }, + "usage": null } ], - "created": 1708626030, + "created": 1708957017, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "1.4.2-native", "usage": { "completion_tokens": 21, - "prompt_tokens": 189, - "total_tokens": 210 + "prompt_tokens": 184, + "total_tokens": 205 } } diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index 5f5ab4c8..ecabf534 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -71,7 +71,6 @@ tools = [ ] -@pytest.mark.skip @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_no_tools( @@ -99,7 +98,6 @@ async def test_flash_llama_grammar_no_tools( assert response == response_snapshot -@pytest.mark.skip @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot): @@ -119,11 +117,59 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna }, ], ) - assert len(response.choices[0].message.content) == 78 - assert ( - response.choices[0].message.content - == """{"function":{"format": "celsius", "location": "New York, NY", "num_days": 14}}""" + assert response.choices[0].message.content == None + assert response.choices[0].message.tool_calls == { + "function": { + "description": None, + "name": "tools", + "parameters": { + "format": "celsius", + "location": "San Francisco", + "num_days": 2, + }, + }, + "id": 0, + "type": "function", + } + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_auto( + flash_llama_grammar_tools, response_snapshot +): + response = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=1, + tools=tools, + tool_choice="auto", + presence_penalty=-1.1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], ) + assert response.choices[0].message.content == None + assert response.choices[0].message.tool_calls == { + "function": { + "description": None, + "name": "tools", + "parameters": { + "format": "celsius", + "location": "San Francisco", + "num_days": 2, + }, + }, + "id": 0, + "type": "function", + } assert response == response_snapshot @@ -149,9 +195,14 @@ async def test_flash_llama_grammar_tools_choice( }, ], ) - assert len(response.choices[0].message.content) == 62 - assert ( - response.choices[0].message.content - == """{"function":{"format": "celsius", "location": "New York, NY"}}""" - ) + assert response.choices[0].message.content == None + assert response.choices[0].message.tool_calls == { + "id": 0, + "type": "function", + "function": { + "description": None, + "name": "tools", + "parameters": {"format": "celsius", "location": "New York, NY"}, + }, + } assert response == response_snapshot diff --git a/router/src/lib.rs b/router/src/lib.rs index 5dfa8c7d..011d9fd9 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -581,6 +581,7 @@ mod deserialize_tool_choice { Err(de::Error::custom("function key not found in tool choice")) } } + Value::Null => Ok(Some(ToolType::OneOf)), _ => Err(de::Error::custom("invalid token format")), } }