feat: avoid skip tool test and avoid empty tool prompts

This commit is contained in:
drbh 2024-08-26 19:15:05 +00:00
parent 1f72dcf062
commit 20db2c3db8
4 changed files with 31 additions and 40 deletions

View File

@ -757,7 +757,12 @@ class AsyncClient:
continue continue
payload = byte_payload.decode("utf-8") payload = byte_payload.decode("utf-8")
if payload.startswith("data:"): if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) payload_data = (
payload.lstrip("data:").rstrip("\n").removeprefix(" ")
)
if payload_data == "[DONE]":
break
json_payload = json.loads(payload_data)
try: try:
response = ChatCompletionChunk(**json_payload) response = ChatCompletionChunk(**json_payload)
yield response yield response

View File

@ -36,6 +36,7 @@ tools = [
}, },
}, },
"required": ["location", "format"], "required": ["location", "format"],
"additionalProperties": False,
}, },
}, },
}, },
@ -62,13 +63,13 @@ tools = [
}, },
}, },
"required": ["location", "format", "num_days"], "required": ["location", "format", "num_days"],
"additionalProperties": False,
}, },
}, },
}, },
] ]
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot): async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
@ -76,7 +77,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
max_tokens=100, max_tokens=100,
seed=1, seed=1,
tools=tools, tools=tools,
presence_penalty=-1.1, temperature=0.0,
messages=[ messages=[
{ {
"role": "system", "role": "system",
@ -91,19 +92,18 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
assert response.choices[0].message.content is None assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [ assert response.choices[0].message.tool_calls == [
{ {
"id": 0, "id": "0",
"type": "function", "type": "function",
"function": { "function": {
"description": None, "description": None,
"name": "get_current_weather", "name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"}, "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
}, },
} }
] ]
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_auto( async def test_flash_llama_grammar_tools_auto(
@ -113,8 +113,8 @@ async def test_flash_llama_grammar_tools_auto(
max_tokens=100, max_tokens=100,
seed=1, seed=1,
tools=tools, tools=tools,
temperature=0.0,
tool_choice="auto", tool_choice="auto",
presence_penalty=-1.1,
messages=[ messages=[
{ {
"role": "system", "role": "system",
@ -129,12 +129,12 @@ async def test_flash_llama_grammar_tools_auto(
assert response.choices[0].message.content is None assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [ assert response.choices[0].message.tool_calls == [
{ {
"id": 0, "id": "0",
"type": "function", "type": "function",
"function": { "function": {
"description": None, "description": None,
"name": "get_current_weather", "name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"}, "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
}, },
} }
] ]
@ -142,7 +142,6 @@ async def test_flash_llama_grammar_tools_auto(
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_choice( async def test_flash_llama_grammar_tools_choice(
@ -152,8 +151,8 @@ async def test_flash_llama_grammar_tools_choice(
max_tokens=100, max_tokens=100,
seed=1, seed=1,
tools=tools, tools=tools,
temperature=0.0,
tool_choice="get_current_weather", tool_choice="get_current_weather",
presence_penalty=-1.1,
messages=[ messages=[
{ {
"role": "system", "role": "system",
@ -168,12 +167,12 @@ async def test_flash_llama_grammar_tools_choice(
assert response.choices[0].message.content is None assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [ assert response.choices[0].message.tool_calls == [
{ {
"id": 0, "id": "0",
"type": "function", "type": "function",
"function": { "function": {
"description": None, "description": None,
"name": "get_current_weather", "name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"}, "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
}, },
} }
] ]
@ -181,7 +180,6 @@ async def test_flash_llama_grammar_tools_choice(
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_stream( async def test_flash_llama_grammar_tools_stream(
@ -191,8 +189,8 @@ async def test_flash_llama_grammar_tools_stream(
max_tokens=100, max_tokens=100,
seed=1, seed=1,
tools=tools, tools=tools,
temperature=0.0,
tool_choice="get_current_weather", tool_choice="get_current_weather",
presence_penalty=-1.1,
messages=[ messages=[
{ {
"role": "system", "role": "system",
@ -210,11 +208,10 @@ async def test_flash_llama_grammar_tools_stream(
async for response in responses: async for response in responses:
count += 1 count += 1
assert count == 38 assert count == 48
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information( async def test_flash_llama_grammar_tools_insufficient_information(
@ -222,13 +219,13 @@ async def test_flash_llama_grammar_tools_insufficient_information(
): ):
responses = await flash_llama_grammar_tools.chat( responses = await flash_llama_grammar_tools.chat(
max_tokens=100, max_tokens=100,
seed=8, seed=24,
tools=tools, tools=tools,
tool_choice="auto", tool_choice="auto",
messages=[ messages=[
{ {
"role": "system", "role": "system",
"content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
}, },
{ {
"role": "user", "role": "user",
@ -239,18 +236,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
) )
assert responses.choices[0].message.content is None assert responses.choices[0].message.content is None
assert responses.choices[0].message.tool_calls == [ assert (
{ responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
"function": { )
"arguments": {
"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
},
"description": None,
"name": "notify_error",
},
"id": 0,
"type": "function",
}
]
assert responses == response_snapshot assert responses == response_snapshot

View File

@ -840,7 +840,7 @@ pub(crate) struct ChatRequest {
pub tools: Option<Vec<Tool>>, pub tools: Option<Vec<Tool>>,
/// A prompt to be appended before the tools /// A prompt to be appended before the tools
#[serde(default = "default_tool_prompt")] #[serde(default)]
#[schema( #[schema(
nullable = true, nullable = true,
example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables." example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
@ -865,10 +865,8 @@ pub(crate) struct ChatRequest {
pub guideline: Option<String>, pub guideline: Option<String>,
} }
fn default_tool_prompt() -> Option<String> { pub fn default_tool_prompt() -> String {
Some( "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string(),
)
} }
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] #[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]

View File

@ -8,7 +8,7 @@ use crate::kserve::{
kserve_model_metadata, kserve_model_metadata_ready, kserve_model_metadata, kserve_model_metadata_ready,
}; };
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::ChatTokenizeResponse; use crate::{default_tool_prompt, ChatTokenizeResponse};
use crate::{ use crate::{
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@ -1158,7 +1158,9 @@ async fn chat_completions(
let repetition_penalty = presence_penalty.map(|x| x + 2.0); let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let max_new_tokens = max_tokens.or(Some(100)); let max_new_tokens = max_tokens.or(Some(100));
let logprobs = logprobs.unwrap_or(false); let logprobs = logprobs.unwrap_or(false);
let tool_prompt = tool_prompt.unwrap_or_default(); let tool_prompt = tool_prompt
.filter(|s| !s.is_empty())
.unwrap_or_else(default_tool_prompt);
let stop = stop.unwrap_or_default(); let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0 // enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature { let (do_sample, temperature) = match temperature {