fix: prefer no_tool over notify_error to improve response

This commit is contained in:
David Holtz 2024-10-09 18:13:34 +00:00
parent fa140a2eeb
commit b48eca405a
6 changed files with 89 additions and 36 deletions

View File

@ -5,7 +5,7 @@
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"message": { "message": {
"content": "There is no weather related function available to answer your prompt.", "content": "I am a helpful assistant!",
"name": null, "name": null,
"role": "assistant", "role": "assistant",
"tool_calls": null "tool_calls": null
@ -13,14 +13,14 @@
"usage": null "usage": null
} }
], ],
"created": 1728306098, "created": 1728497062,
"id": "", "id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct", "model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion", "object": "chat.completion",
"system_fingerprint": "2.3.2-dev0-native", "system_fingerprint": "2.3.2-dev0-native",
"usage": { "usage": {
"completion_tokens": 29, "completion_tokens": 23,
"prompt_tokens": 616, "prompt_tokens": 604,
"total_tokens": 645 "total_tokens": 627
} }
} }

View File

@ -2,7 +2,7 @@
"choices": [ "choices": [
{ {
"delta": { "delta": {
"content": " prompt", "content": " assistant",
"role": "assistant", "role": "assistant",
"tool_calls": null "tool_calls": null
}, },
@ -11,7 +11,7 @@
"logprobs": null "logprobs": null
} }
], ],
"created": 1728494305, "created": 1728497531,
"id": "", "id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct", "model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk", "object": "chat.completion.chunk",

View File

@ -0,0 +1,20 @@
{
"choices": [
{
"delta": {
"content": " fans",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1728497461,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}

View File

@ -236,21 +236,18 @@ async def test_flash_llama_grammar_tools_insufficient_information(
messages=[ messages=[
{ {
"role": "system", "role": "system",
"content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", "content": "You're a helpful assistant! Answer the users question best you can.",
}, },
{ {
"role": "user", "role": "user",
"content": "Tell me a story about 3 sea creatures", "content": "Who are you?",
}, },
], ],
stream=False, stream=False,
) )
assert responses.choices[0].message.tool_calls is None assert responses.choices[0].message.tool_calls is None
assert ( assert responses.choices[0].message.content == "I am a helpful assistant!"
responses.choices[0].message.content
== "There is no weather related function available to answer your prompt."
)
assert responses == response_snapshot assert responses == response_snapshot
@ -268,7 +265,44 @@ async def test_flash_llama_grammar_tools_insufficient_information_stream(
messages=[ messages=[
{ {
"role": "system", "role": "system",
"content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", "content": "You're a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "Who are you?",
},
],
stream=True,
)
count = 0
content_generated = ""
last_response = None
async for response in responses:
count += 1
content_generated += response.choices[0].delta.content
last_response = response
assert response.choices[0].delta.tool_calls is None
assert count == 5
assert content_generated == "I am a helpful assistant"
assert last_response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream(
flash_llama_grammar_tools, response_snapshot
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=24,
tools=tools,
tool_choice="auto",
messages=[
{
"role": "system",
"content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
}, },
{ {
"role": "user", "role": "user",
@ -287,10 +321,9 @@ async def test_flash_llama_grammar_tools_insufficient_information_stream(
last_response = response last_response = response
assert response.choices[0].delta.tool_calls is None assert response.choices[0].delta.tool_calls is None
assert count == 11 assert count == 62
print(content_generated)
assert ( assert (
content_generated content_generated
== "There is no weather related function available to answer your prompt" == "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
) )
assert last_response == response_snapshot assert last_response == response_snapshot

View File

@ -31,25 +31,25 @@ impl ToolGrammar {
let mut tools = tools.clone(); let mut tools = tools.clone();
// add the notify_error function to the tools // add the no_tool function to the tools
let notify_error = Tool { let no_tool = Tool {
r#type: "function".to_string(), r#type: "function".to_string(),
function: FunctionDefinition { function: FunctionDefinition {
name: "notify_error".to_string(), name: "no_tool".to_string(),
description: Some("Notify an error or issue".to_string()), description: Some("Open ened response with no specific tool selected".to_string()),
arguments: json!({ arguments: json!({
"type": "object", "type": "object",
"properties": { "properties": {
"error": { "content": {
"type": "string", "type": "string",
"description": "The error or issue to notify" "description": "The response content",
} }
}, },
"required": ["error"] "required": ["content"]
}), }),
}, },
}; };
tools.push(notify_error); tools.push(no_tool);
// if tools are provided and no tool_choice we default to the OneOf // if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice { let tools_to_use = match tool_choice {

View File

@ -1264,7 +1264,7 @@ async fn chat_completions(
buffer.push(stream_token); buffer.push(stream_token);
if let Some(captures) = function_regex.captures(&json_buffer) { if let Some(captures) = function_regex.captures(&json_buffer) {
let function_name = captures[1].to_string(); let function_name = captures[1].to_string();
if function_name == "notify_error" { if function_name == "no_tool" {
state = StreamState::BufferTrailing; state = StreamState::BufferTrailing;
response_as_tool = false; response_as_tool = false;
buffer.clear(); buffer.clear();
@ -1290,13 +1290,13 @@ async fn chat_completions(
} }
// if we skipped sending the buffer we need to avoid sending the following json key and quotes // if we skipped sending the buffer we need to avoid sending the following json key and quotes
StreamState::BufferTrailing => { StreamState::BufferTrailing => {
let infix_text = "\"error\":\""; let infix_text = "\"content\":\"";
json_buffer.push_str(&token_text.replace(" ", "")); json_buffer.push_str(&token_text.replace(" ", ""));
// keep capturing until we find the infix text // keep capturing until we find the infix text
match json_buffer.find(infix_text) { match json_buffer.find(infix_text) {
Some(error_index) => { Some(content_key_index) => {
json_buffer = json_buffer =
json_buffer[error_index + infix_text.len()..].to_string(); json_buffer[content_key_index + infix_text.len()..].to_string();
} }
None => { None => {
continue; continue;
@ -1390,18 +1390,18 @@ async fn chat_completions(
props.remove("_name"); props.remove("_name");
} }
match name.as_str() { match name.as_str() {
"notify_error" => { "no_tool" => {
// parse the error message // parse the content message
let error_message = arguments let content_message = arguments
.get("error") .get("content")
.and_then(Value::as_str) .and_then(Value::as_str)
.ok_or_else(|| { .ok_or_else(|| {
InferError::ToolError( InferError::ToolError(
"No error message found in generated text".to_string(), "No `content` found in generated text".to_string(),
) )
})? })?
.to_string(); .to_string();
(None, Some(error_message)) (None, Some(content_message))
} }
_ => { _ => {
let tool_calls = vec![ToolCall { let tool_calls = vec![ToolCall {
@ -2662,6 +2662,6 @@ mod tests {
assert!(result.is_ok()); assert!(result.is_ok());
let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input"); let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
assert_eq!(using_tools, true); assert_eq!(using_tools, true);
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. 
San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
} }
} }