feat: avoid skip tool test and avoid empty tool prompts

This commit is contained in:
drbh 2024-08-26 19:15:05 +00:00
parent 1f72dcf062
commit 20db2c3db8
4 changed files with 31 additions and 40 deletions

View File

@ -757,7 +757,12 @@ class AsyncClient:
continue
payload = byte_payload.decode("utf-8")
if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
payload_data = (
payload.lstrip("data:").rstrip("\n").removeprefix(" ")
)
if payload_data == "[DONE]":
break
json_payload = json.loads(payload_data)
try:
response = ChatCompletionChunk(**json_payload)
yield response

View File

@ -36,6 +36,7 @@ tools = [
},
},
"required": ["location", "format"],
"additionalProperties": False,
},
},
},
@ -62,13 +63,13 @@ tools = [
},
},
"required": ["location", "format", "num_days"],
"additionalProperties": False,
},
},
},
]
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
@ -76,7 +77,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
max_tokens=100,
seed=1,
tools=tools,
presence_penalty=-1.1,
temperature=0.0,
messages=[
{
"role": "system",
@ -91,19 +92,18 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
"id": "0",
"type": "function",
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
},
}
]
assert response == response_snapshot
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_auto(
@ -113,8 +113,8 @@ async def test_flash_llama_grammar_tools_auto(
max_tokens=100,
seed=1,
tools=tools,
temperature=0.0,
tool_choice="auto",
presence_penalty=-1.1,
messages=[
{
"role": "system",
@ -129,12 +129,12 @@ async def test_flash_llama_grammar_tools_auto(
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
"id": "0",
"type": "function",
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
},
}
]
@ -142,7 +142,6 @@ async def test_flash_llama_grammar_tools_auto(
assert response == response_snapshot
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_choice(
@ -152,8 +151,8 @@ async def test_flash_llama_grammar_tools_choice(
max_tokens=100,
seed=1,
tools=tools,
temperature=0.0,
tool_choice="get_current_weather",
presence_penalty=-1.1,
messages=[
{
"role": "system",
@ -168,12 +167,12 @@ async def test_flash_llama_grammar_tools_choice(
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
"id": "0",
"type": "function",
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
},
}
]
@ -181,7 +180,6 @@ async def test_flash_llama_grammar_tools_choice(
assert response == response_snapshot
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_stream(
@ -191,8 +189,8 @@ async def test_flash_llama_grammar_tools_stream(
max_tokens=100,
seed=1,
tools=tools,
temperature=0.0,
tool_choice="get_current_weather",
presence_penalty=-1.1,
messages=[
{
"role": "system",
@ -210,11 +208,10 @@ async def test_flash_llama_grammar_tools_stream(
async for response in responses:
count += 1
assert count == 38
assert count == 48
assert response == response_snapshot
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information(
@ -222,13 +219,13 @@ async def test_flash_llama_grammar_tools_insufficient_information(
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=8,
seed=24,
tools=tools,
tool_choice="auto",
messages=[
{
"role": "system",
"content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
"content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
},
{
"role": "user",
@ -239,18 +236,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
)
assert responses.choices[0].message.content is None
assert responses.choices[0].message.tool_calls == [
{
"function": {
"arguments": {
"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
},
"description": None,
"name": "notify_error",
},
"id": 0,
"type": "function",
}
]
assert (
responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
)
assert responses == response_snapshot

View File

@ -840,7 +840,7 @@ pub(crate) struct ChatRequest {
pub tools: Option<Vec<Tool>>,
/// A prompt to be appended before the tools
#[serde(default = "default_tool_prompt")]
#[serde(default)]
#[schema(
nullable = true,
example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
@ -865,10 +865,8 @@ pub(crate) struct ChatRequest {
pub guideline: Option<String>,
}
fn default_tool_prompt() -> Option<String> {
Some(
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string(),
)
pub fn default_tool_prompt() -> String {
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
}
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]

View File

@ -8,7 +8,7 @@ use crate::kserve::{
kserve_model_metadata, kserve_model_metadata_ready,
};
use crate::validation::ValidationError;
use crate::ChatTokenizeResponse;
use crate::{default_tool_prompt, ChatTokenizeResponse};
use crate::{
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@ -1158,7 +1158,9 @@ async fn chat_completions(
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let max_new_tokens = max_tokens.or(Some(100));
let logprobs = logprobs.unwrap_or(false);
let tool_prompt = tool_prompt.unwrap_or_default();
let tool_prompt = tool_prompt
.filter(|s| !s.is_empty())
.unwrap_or_else(default_tool_prompt);
let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature {