From 20db2c3db8731cccbbc709d5f08db46587a78e87 Mon Sep 17 00:00:00 2001
From: drbh
Date: Mon, 26 Aug 2024 19:15:05 +0000
Subject: [PATCH] feat: avoid skip tool test and avoid empty tool prompts

---
 clients/python/text_generation/client.py     |  7 ++-
 integration-tests/models/test_tools_llama.py | 50 +++++++-------
 router/src/lib.rs                            |  8 ++--
 router/src/server.rs                         |  6 ++-
 4 files changed, 31 insertions(+), 40 deletions(-)

diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index 12966747..45301b63 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -757,7 +757,12 @@ class AsyncClient:
                         continue
                     payload = byte_payload.decode("utf-8")
                     if payload.startswith("data:"):
-                        json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                        payload_data = (
+                            payload.lstrip("data:").rstrip("\n").removeprefix(" ")
+                        )
+                        if payload_data == "[DONE]":
+                            break
+                        json_payload = json.loads(payload_data)
                         try:
                             response = ChatCompletionChunk(**json_payload)
                             yield response
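Note on the client.py hunk: OpenAI-compatible streaming endpoints terminate the SSE stream with a literal "data: [DONE]" line rather than a final JSON chunk, so passing it to json.loads would raise a JSONDecodeError. The snippet below is a minimal standalone sketch of the parsing rule the hunk introduces; the helper name iter_chat_chunks is hypothetical and not part of the patch.

    import json

    def iter_chat_chunks(byte_lines):
        # Hypothetical helper mirroring the patched AsyncClient loop:
        # parse "data:" SSE lines into JSON chunks, stopping at [DONE].
        for byte_payload in byte_lines:
            payload = byte_payload.decode("utf-8")
            if not payload.startswith("data:"):
                continue
            # Drop the "data:" prefix, the trailing newline, and one
            # optional space, exactly as the patched client does.
            payload_data = payload.lstrip("data:").rstrip("\n").removeprefix(" ")
            # The "[DONE]" sentinel is not JSON; stop instead of raising.
            if payload_data == "[DONE]":
                break
            yield json.loads(payload_data)

    # The sentinel line is now consumed cleanly instead of raising:
    lines = [b'data: {"id": "0"}\n', b"data: [DONE]\n"]
    assert list(iter_chat_chunks(lines)) == [{"id": "0"}]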
tool_choice="get_current_weather", - presence_penalty=-1.1, messages=[ { "role": "system", @@ -168,12 +167,12 @@ async def test_flash_llama_grammar_tools_choice( assert response.choices[0].message.content is None assert response.choices[0].message.tool_calls == [ { - "id": 0, + "id": "0", "type": "function", "function": { "description": None, "name": "get_current_weather", - "arguments": {"format": "celsius", "location": "New York, NY"}, + "arguments": {"format": "celsius", "location": "Brooklyn, NY"}, }, } ] @@ -181,7 +180,6 @@ async def test_flash_llama_grammar_tools_choice( assert response == response_snapshot -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_stream( @@ -191,8 +189,8 @@ async def test_flash_llama_grammar_tools_stream( max_tokens=100, seed=1, tools=tools, + temperature=0.0, tool_choice="get_current_weather", - presence_penalty=-1.1, messages=[ { "role": "system", @@ -210,11 +208,10 @@ async def test_flash_llama_grammar_tools_stream( async for response in responses: count += 1 - assert count == 38 + assert count == 48 assert response == response_snapshot -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_insufficient_information( @@ -222,13 +219,13 @@ async def test_flash_llama_grammar_tools_insufficient_information( ): responses = await flash_llama_grammar_tools.chat( max_tokens=100, - seed=8, + seed=24, tools=tools, tool_choice="auto", messages=[ { "role": "system", - "content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", + "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", }, { "role": "user", @@ -239,18 +236,7 @@ async def test_flash_llama_grammar_tools_insufficient_information( ) assert responses.choices[0].message.content is None - assert responses.choices[0].message.tool_calls == [ - { - "function": { - "arguments": { - "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options." - }, - "description": None, - "name": "notify_error", - }, - "id": 0, - "type": "function", - } - ] - + assert ( + responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error" + ) assert responses == response_snapshot diff --git a/router/src/lib.rs b/router/src/lib.rs index af77d436..ce4f7c46 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -840,7 +840,7 @@ pub(crate) struct ChatRequest { pub tools: Option>, /// A prompt to be appended before the tools - #[serde(default = "default_tool_prompt")] + #[serde(default)] #[schema( nullable = true, example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables." @@ -865,10 +865,8 @@ pub(crate) struct ChatRequest { pub guideline: Option, } -fn default_tool_prompt() -> Option { - Some( - "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. 
diff --git a/router/src/lib.rs b/router/src/lib.rs
index af77d436..ce4f7c46 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -840,7 +840,7 @@ pub(crate) struct ChatRequest {
     pub tools: Option<Vec<Tool>>,
 
     /// A prompt to be appended before the tools
-    #[serde(default = "default_tool_prompt")]
+    #[serde(default)]
     #[schema(
         nullable = true,
         example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
@@ -865,10 +865,8 @@ pub(crate) struct ChatRequest {
     pub guideline: Option<String>,
 }
 
-fn default_tool_prompt() -> Option<String> {
-    Some(
-        "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string(),
-    )
+pub fn default_tool_prompt() -> String {
+    "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
 }
 
 #[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
diff --git a/router/src/server.rs b/router/src/server.rs
index f8e256bc..8ebd1a33 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -8,7 +8,7 @@ use crate::kserve::{
     kserve_model_metadata, kserve_model_metadata_ready,
 };
 use crate::validation::ValidationError;
-use crate::ChatTokenizeResponse;
+use crate::{default_tool_prompt, ChatTokenizeResponse};
 use crate::{
     usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
     GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@@ -1158,7 +1158,9 @@ async fn chat_completions(
     let repetition_penalty = presence_penalty.map(|x| x + 2.0);
     let max_new_tokens = max_tokens.or(Some(100));
     let logprobs = logprobs.unwrap_or(false);
-    let tool_prompt = tool_prompt.unwrap_or_default();
+    let tool_prompt = tool_prompt
+        .filter(|s| !s.is_empty())
+        .unwrap_or_else(default_tool_prompt);
     let stop = stop.unwrap_or_default();
     // enable greedy only when temperature is 0
     let (do_sample, temperature) = match temperature {
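Note on the router changes: previously serde applied default_tool_prompt only when the tool_prompt field was absent, so a request carrying an explicit empty string kept it, and the tools were appended with no instructions. With the serde default removed, chat_completions now treats None and "" the same. A standalone Python sketch of the new fallback rule (resolve_tool_prompt is a hypothetical name; the prompt text is the one from lib.rs above):

    DEFAULT_TOOL_PROMPT = (
        "\nGiven the functions available, please respond with a JSON for a function"
        " call with its proper arguments that best answers the given prompt. Respond"
        " in the format {name: function name, parameters: dictionary of argument"
        " name and its value}.Do not use variables.\n"
    )

    def resolve_tool_prompt(tool_prompt):
        # Mirrors tool_prompt.filter(|s| !s.is_empty()).unwrap_or_else(default_tool_prompt):
        # None (field absent) and "" (field sent but empty) both get the default.
        return tool_prompt if tool_prompt else DEFAULT_TOOL_PROMPT

    assert resolve_tool_prompt(None) == DEFAULT_TOOL_PROMPT
    assert resolve_tool_prompt("") == DEFAULT_TOOL_PROMPT
    assert resolve_tool_prompt("Use a tool.") == "Use a tool."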