mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: simplify tool choice logic, improve tests, openapi and rust docs
This commit is contained in:
parent
f53c8059e9
commit
b2db1075e4
@ -955,7 +955,8 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
],
|
||||
"description": "<https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>"
|
||||
},
|
||||
"ChatCompletionTopLogprob": {
|
||||
"type": "object",
|
||||
|
@ -0,0 +1,27 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"role": "assistant",
|
||||
"tool_calls": {
|
||||
"function": {
|
||||
"arguments": "<|eot_id|>",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1729000499,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "2.3.2-dev0-native",
|
||||
"usage": null
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": {
|
||||
"function": {
|
||||
"arguments": "<|eot_id|>",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1728998230,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "2.3.2-dev0-native",
|
||||
"usage": null
|
||||
}
|
@ -1,4 +1,6 @@
|
||||
import pytest
|
||||
import requests
|
||||
import json
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@ -174,7 +176,7 @@ async def test_flash_llama_grammar_tools_choice(
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "get_current_weather",
|
||||
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
|
||||
"arguments": {"format": "celsius", "location": "Brooklyn, New York"},
|
||||
},
|
||||
}
|
||||
]
|
||||
@ -327,3 +329,102 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream(
|
||||
== "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
|
||||
)
|
||||
assert last_response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
    flash_llama_grammar_tools, response_snapshot
):
    """Stream a chat completion with ``tool_choice="required"``.

    Every streamed delta must carry tool-call argument text and no plain
    ``content``; the concatenated arguments and the final chunk are compared
    against the expected values.
    """
    conversation = [
        {
            "role": "system",
            "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
        },
        {
            "role": "user",
            "content": "Tell me a story about 3 sea creatures",
        },
    ]
    stream = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,
        tools=tools,
        tool_choice="required",
        messages=conversation,
        stream=True,
    )

    # Collect the argument fragment of each streamed delta in arrival order.
    argument_fragments = []
    final_chunk = None
    async for chunk in stream:
        delta = chunk.choices[0].delta
        assert delta.content is None
        argument_fragments.append(delta.tool_calls.function.arguments)
        final_chunk = chunk

    assert len(argument_fragments) == 29
    assert (
        "".join(argument_fragments)
        == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "San Francisco, CA"}}<|eot_id|>'
    )
    assert final_chunk == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
    flash_llama_grammar_tools, response_snapshot
):
    """Stream a chat completion with ``tool_choice`` given as a function object.

    The request is sent to ``/v1/chat/completions`` with ``requests`` and the
    streamed body is parsed by hand: each non-empty chunk is counted, the
    ``data: `` prefixes are stripped and every JSON event contributes its
    tool-call argument fragment. The concatenated arguments and the last
    parsed event are compared against the expected values.
    """
    # using `requests` to send the request until the client library supports tool_choice as a function object
    request_body = {
        "model": "tgi",
        "messages": [
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
            },
            {
                "role": "user",
                "content": "Tell me a story about 3 sea creatures",
            },
        ],
        "tools": tools,
        "tool_choice": {
            "type": "function",
            "function": {"name": "get_current_weather"},
        },
        "seed": 24,
        "max_tokens": 100,
        "stream": True,
    }

    count = 0
    tool_calls_generated = ""
    last_response = None
    # The context manager releases the streamed connection even when an
    # assertion or a parse error interrupts the loop.
    with requests.post(
        f"{flash_llama_grammar_tools.base_url}/v1/chat/completions",
        headers=flash_llama_grammar_tools.headers,
        json=request_body,
        stream=True,
    ) as responses:
        # Fail with the HTTP error instead of an opaque JSON/KeyError later on.
        responses.raise_for_status()

        # iterate over the response in chunks
        for chunk in responses.iter_content(chunk_size=1024):
            if not chunk:
                continue
            count += 1
            # remove the "data: " prefix, trailing newline, and split the chunk into individual lines
            lines = chunk.decode("utf-8").replace("data: ", "").rstrip("\n").split("\n")
            for line in lines:
                if not line:
                    # Blank separator between two events that arrived in the
                    # same chunk; json.loads("") would raise here.
                    continue
                if line == "[DONE]":
                    break
                response = json.loads(line)
                tool_calls_generated += response["choices"][0]["delta"]["tool_calls"][
                    "function"
                ]["arguments"]
                last_response = response

    assert count == 30
    print(tool_calls_generated)
    assert (
        tool_calls_generated
        == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Tokyo, JP"}}<|eot_id|>'
    )
    assert last_response == response_snapshot
|
||||
|
@ -27,39 +27,37 @@ impl ToolGrammar {
|
||||
return Ok((tools, None));
|
||||
}
|
||||
|
||||
let mut tools = tools.clone();
|
||||
|
||||
// add the no_tool function to the tools as long as we are not required to use a specific tool
|
||||
if tool_choice != ChatCompletionToolChoiceOption::Required {
|
||||
let no_tool = Tool {
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionDefinition {
|
||||
name: "no_tool".to_string(),
|
||||
description: Some(
|
||||
"Open ended response with no specific tool selected".to_string(),
|
||||
),
|
||||
arguments: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The response content",
|
||||
}
|
||||
},
|
||||
"required": ["content"]
|
||||
}),
|
||||
},
|
||||
};
|
||||
tools.push(no_tool);
|
||||
}
|
||||
|
||||
// if tools are provided and no tool_choice we default to the OneOf
|
||||
let tools_to_use = match tool_choice {
|
||||
ChatCompletionToolChoiceOption::Function(function) => {
|
||||
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
||||
}
|
||||
ChatCompletionToolChoiceOption::Required => tools.clone(),
|
||||
ChatCompletionToolChoiceOption::Auto => tools.clone(),
|
||||
ChatCompletionToolChoiceOption::Required => tools,
|
||||
ChatCompletionToolChoiceOption::Auto => {
|
||||
// only add the no_tool function if the user has selected the auto option
|
||||
tools
|
||||
.iter()
|
||||
.cloned()
|
||||
.chain(std::iter::once(Tool {
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionDefinition {
|
||||
name: "no_tool".to_string(),
|
||||
description: Some(
|
||||
"Open ended response with no specific tool selected".to_string(),
|
||||
),
|
||||
arguments: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The response content",
|
||||
}
|
||||
},
|
||||
"required": ["content"]
|
||||
}),
|
||||
},
|
||||
}))
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
ChatCompletionToolChoiceOption::NoTool => return Ok((tools, None)),
|
||||
};
|
||||
|
||||
@ -121,6 +119,6 @@ impl ToolGrammar {
|
||||
},
|
||||
};
|
||||
|
||||
Ok((tools, Some(tool_schema)))
|
||||
Ok((tools_to_use, Some(tool_schema)))
|
||||
}
|
||||
}
|
||||
|
@ -946,21 +946,19 @@ impl ChatRequest {
|
||||
Some(temperature) if temperature == 0.0 => (false, None),
|
||||
other => (true, other),
|
||||
};
|
||||
// unwrap or default (use "auto" if tools are present, and "none" if not)
|
||||
let choice = tool_choice.unwrap_or_else(|| {
|
||||
if tools.is_some() {
|
||||
ChatCompletionToolChoiceOption::Auto
|
||||
} else {
|
||||
ChatCompletionToolChoiceOption::NoTool
|
||||
}
|
||||
});
|
||||
let (inputs, grammar, using_tools) = prepare_chat_input(
|
||||
infer,
|
||||
response_format,
|
||||
tools.clone(),
|
||||
// unwrap or default (use "auto" if tools are present, and "none" if not)
|
||||
tool_choice.map_or_else(
|
||||
|| {
|
||||
if tools.is_some() {
|
||||
ChatCompletionToolChoiceOption::Auto
|
||||
} else {
|
||||
ChatCompletionToolChoiceOption::NoTool
|
||||
}
|
||||
},
|
||||
|t| t,
|
||||
),
|
||||
tools,
|
||||
choice,
|
||||
&tool_prompt,
|
||||
guideline,
|
||||
messages,
|
||||
@ -1023,6 +1021,7 @@ pub struct FunctionName {
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)]
|
||||
#[serde(from = "ToolTypeDeserializer")]
|
||||
/// <https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>
|
||||
pub enum ChatCompletionToolChoiceOption {
|
||||
/// Means the model can pick between generating a message or calling one or more tools.
|
||||
#[schema(rename = "auto")]
|
||||
@ -1034,7 +1033,7 @@ pub enum ChatCompletionToolChoiceOption {
|
||||
/// Means the model must call one or more tools.
|
||||
#[schema(rename = "required")]
|
||||
Required,
|
||||
/// Forces the model to call a specific tool.
|
||||
/// Forces the model to call a specific tool. This structure aligns with the `OpenAI` API schema to force a specific tool.
|
||||
#[schema(rename = "function")]
|
||||
#[serde(alias = "function")]
|
||||
Function(FunctionName),
|
||||
@ -1688,32 +1687,36 @@ mod tests {
|
||||
tool_choice: ChatCompletionToolChoiceOption,
|
||||
}
|
||||
|
||||
let none = r#"{"tool_choice":"none"}"#;
|
||||
let de_none: TestRequest = serde_json::from_str(none).unwrap();
|
||||
let de_none: TestRequest = serde_json::from_str(r#"{"tool_choice":"none"}"#).unwrap();
|
||||
assert_eq!(de_none.tool_choice, ChatCompletionToolChoiceOption::NoTool);
|
||||
|
||||
let auto = r#"{"tool_choice":"auto"}"#;
|
||||
let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
|
||||
let de_auto: TestRequest = serde_json::from_str(r#"{"tool_choice":"auto"}"#).unwrap();
|
||||
assert_eq!(de_auto.tool_choice, ChatCompletionToolChoiceOption::Auto);
|
||||
|
||||
let auto = r#"{"tool_choice":"required"}"#;
|
||||
let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
|
||||
let de_required: TestRequest =
|
||||
serde_json::from_str(r#"{"tool_choice":"required"}"#).unwrap();
|
||||
assert_eq!(
|
||||
de_auto.tool_choice,
|
||||
de_required.tool_choice,
|
||||
ChatCompletionToolChoiceOption::Required
|
||||
);
|
||||
|
||||
let ref_choice = ChatCompletionToolChoiceOption::Function(FunctionName {
|
||||
name: "myfn".to_string(),
|
||||
});
|
||||
let de_named: TestRequest = serde_json::from_str(r#"{"tool_choice":"myfn"}"#).unwrap();
|
||||
assert_eq!(
|
||||
de_named.tool_choice,
|
||||
ChatCompletionToolChoiceOption::Function(FunctionName {
|
||||
name: "myfn".to_string(),
|
||||
})
|
||||
);
|
||||
|
||||
let named = r#"{"tool_choice":"myfn"}"#;
|
||||
let de_named: TestRequest = serde_json::from_str(named).unwrap();
|
||||
assert_eq!(de_named.tool_choice, ref_choice);
|
||||
|
||||
let openai_named = r#"{"tool_choice":{"type":"function","function":{"name":"myfn"}}}"#;
|
||||
let de_openai_named: TestRequest = serde_json::from_str(openai_named).unwrap();
|
||||
|
||||
assert_eq!(de_openai_named.tool_choice, ref_choice);
|
||||
let de_openai_named: TestRequest = serde_json::from_str(
|
||||
r#"{"tool_choice":{"type":"function","function":{"name":"myfn"}}}"#,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
de_openai_named.tool_choice,
|
||||
ChatCompletionToolChoiceOption::Function(FunctionName {
|
||||
name: "myfn".to_string(),
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -2668,6 +2668,6 @@ mod tests {
|
||||
assert!(result.is_ok());
|
||||
let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
|
||||
assert_eq!(using_tools, true);
|
||||
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ened response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
|
||||
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ended response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user