mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00

fix: simplify tool choice logic, improve tests, openapi and rust docs

commit b2db1075e4 (parent f53c8059e9)
@@ -955,7 +955,8 @@
           }
         }
       }
-    ]
+    ],
+    "description": "<https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>"
   },
   "ChatCompletionTopLogprob": {
     "type": "object",
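For context, the `tool_choice` field this description documents accepts four forms, mirroring the OpenAI API (see the `ChatCompletionToolChoiceOption` enum later in this commit). A minimal sketch of request payloads, assuming a `tools` list like the one used in the integration tests; the model name and message are placeholders:

# Hypothetical payloads illustrating the four accepted `tool_choice` forms;
# `tools` is assumed to be a list of tool definitions as in the tests below.
base = {
    "model": "tgi",
    "messages": [{"role": "user", "content": "What is the weather like in New York?"}],
    "tools": tools,
}

payload_auto = {**base, "tool_choice": "auto"}          # model may answer or call a tool
payload_none = {**base, "tool_choice": "none"}          # model must answer in plain text
payload_required = {**base, "tool_choice": "required"}  # model must call one or more tools
payload_function = {
    **base,
    "tool_choice": {"type": "function", "function": {"name": "get_current_weather"}},
}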
@@ -0,0 +1,27 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "role": "assistant",
+        "tool_calls": {
+          "function": {
+            "arguments": "<|eot_id|>",
+            "name": null
+          },
+          "id": "",
+          "index": 0,
+          "type": "function"
+        }
+      },
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null
+    }
+  ],
+  "created": 1729000499,
+  "id": "",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
+  "object": "chat.completion.chunk",
+  "system_fingerprint": "2.3.2-dev0-native",
+  "usage": null
+}
@@ -0,0 +1,28 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "content": null,
+        "role": "assistant",
+        "tool_calls": {
+          "function": {
+            "arguments": "<|eot_id|>",
+            "name": null
+          },
+          "id": "",
+          "index": 0,
+          "type": "function"
+        }
+      },
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null
+    }
+  ],
+  "created": 1728998230,
+  "id": "",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
+  "object": "chat.completion.chunk",
+  "system_fingerprint": "2.3.2-dev0-native",
+  "usage": null
+}
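The two snapshots above are final chunks of a streamed tool call: each chunk carries a fragment in `delta.tool_calls.function.arguments`, and Llama 3.1 terminates the call with its `<|eot_id|>` token. A rough client-side reassembly sketch (illustrative only), assuming chunks have already been parsed from the SSE stream into dicts shaped like these snapshots:

# Sketch only: `chunks` is assumed to be already-parsed JSON chunk dicts.
def collect_tool_call_arguments(chunks):
    arguments = ""
    for chunk in chunks:
        tool_call = chunk["choices"][0]["delta"].get("tool_calls")
        if tool_call:
            arguments += tool_call["function"]["arguments"]
    # drop the trailing Llama 3.1 end-of-turn token before parsing as JSON
    return arguments.removesuffix("<|eot_id|>")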
@@ -1,4 +1,6 @@
 import pytest
+import requests
+import json


 @pytest.fixture(scope="module")
@@ -174,7 +176,7 @@ async def test_flash_llama_grammar_tools_choice(
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
+                "arguments": {"format": "celsius", "location": "Brooklyn, New York"},
             },
         }
     ]
@@ -327,3 +329,102 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream(
         == "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
     )
     assert last_response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
+    flash_llama_grammar_tools, response_snapshot
+):
+    responses = await flash_llama_grammar_tools.chat(
+        max_tokens=100,
+        seed=24,
+        tools=tools,
+        tool_choice="required",
+        messages=[
+            {
+                "role": "system",
+                "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
+            },
+            {
+                "role": "user",
+                "content": "Tell me a story about 3 sea creatures",
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    tool_calls_generated = ""
+    last_response = None
+    async for response in responses:
+        count += 1
+        assert response.choices[0].delta.content is None
+        tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
+        last_response = response
+
+    assert count == 29
+    assert (
+        tool_calls_generated
+        == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "San Francisco, CA"}}<|eot_id|>'
+    )
+    assert last_response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
+    flash_llama_grammar_tools, response_snapshot
+):
+    # using `requests` to send the request until the client library supports tool_choice as a function object
+    responses = requests.post(
+        f"{flash_llama_grammar_tools.base_url}/v1/chat/completions",
+        headers=flash_llama_grammar_tools.headers,
+        json={
+            "model": "tgi",
+            "messages": [
+                {
+                    "role": "system",
+                    "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
+                },
+                {
+                    "role": "user",
+                    "content": "Tell me a story about 3 sea creatures",
+                },
+            ],
+            "tools": tools,
+            "tool_choice": {
+                "type": "function",
+                "function": {"name": "get_current_weather"},
+            },
+            "seed": 24,
+            "max_tokens": 100,
+            "stream": True,
+        },
+        stream=True,
+    )
+    # iterate over the response in chunks
+    count = 0
+    tool_calls_generated = ""
+    last_response = None
+    for chunk in responses.iter_content(chunk_size=1024):
+        if chunk:
+            count += 1
+            # remove the "data: " prefix, trailing newline, and split the chunk into individual lines
+            lines = chunk.decode("utf-8").replace("data: ", "").rstrip("\n").split("\n")
+            for line in lines:
+                if line == "[DONE]":
+                    break
+                response = json.loads(line)
+                tool_calls_generated += response["choices"][0]["delta"]["tool_calls"][
+                    "function"
+                ]["arguments"]
+                last_response = response
+
+    assert count == 30
+    print(tool_calls_generated)
+    assert (
+        tool_calls_generated
+        == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Tokyo, JP"}}<|eot_id|>'
+    )
+    assert last_response == response_snapshot
@@ -27,39 +27,37 @@ impl ToolGrammar {
             return Ok((tools, None));
         }

-        let mut tools = tools.clone();
-
-        // add the no_tool function to the tools as long as we are not required to use a specific tool
-        if tool_choice != ChatCompletionToolChoiceOption::Required {
-            let no_tool = Tool {
-                r#type: "function".to_string(),
-                function: FunctionDefinition {
-                    name: "no_tool".to_string(),
-                    description: Some(
-                        "Open ended response with no specific tool selected".to_string(),
-                    ),
-                    arguments: json!({
-                        "type": "object",
-                        "properties": {
-                            "content": {
-                                "type": "string",
-                                "description": "The response content",
-                            }
-                        },
-                        "required": ["content"]
-                    }),
-                },
-            };
-            tools.push(no_tool);
-        }
-
-        // if tools are provided and no tool_choice we default to the OneOf
         let tools_to_use = match tool_choice {
             ChatCompletionToolChoiceOption::Function(function) => {
                 vec![Self::find_tool_by_name(&tools, &function.name)?]
             }
-            ChatCompletionToolChoiceOption::Required => tools.clone(),
-            ChatCompletionToolChoiceOption::Auto => tools.clone(),
+            ChatCompletionToolChoiceOption::Required => tools,
+            ChatCompletionToolChoiceOption::Auto => {
+                // only add the no_tool function if the user has selected the auto option
+                tools
+                    .iter()
+                    .cloned()
+                    .chain(std::iter::once(Tool {
+                        r#type: "function".to_string(),
+                        function: FunctionDefinition {
+                            name: "no_tool".to_string(),
+                            description: Some(
+                                "Open ended response with no specific tool selected".to_string(),
+                            ),
+                            arguments: json!({
+                                "type": "object",
+                                "properties": {
+                                    "content": {
+                                        "type": "string",
+                                        "description": "The response content",
+                                    }
+                                },
+                                "required": ["content"]
+                            }),
+                        },
+                    }))
+                    .collect::<Vec<_>>()
+            }
             ChatCompletionToolChoiceOption::NoTool => return Ok((tools, None)),
         };
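The net effect of this refactor is that the `no_tool` escape hatch is appended only in the `Auto` arm, rather than being pushed up front and then special-cased away for `Required`. A Python restatement of the selection semantics (illustrative only; the helper names stand in for the Rust ones):

# Not the actual implementation: a sketch of which tools each choice exposes.
def find_tool_by_name(tools, name):
    # stand-in for ToolGrammar::find_tool_by_name; raises if the tool is unknown
    return next(t for t in tools if t["function"]["name"] == name)

def tools_to_use(tools, tool_choice, no_tool):
    if tool_choice == "none":
        return tools  # upstream returns early: no tool grammar is applied
    if isinstance(tool_choice, dict):  # a specific function was forced
        return [find_tool_by_name(tools, tool_choice["function"]["name"])]
    if tool_choice == "required":
        return tools  # the model must pick one of the real tools
    if tool_choice == "auto":
        return tools + [no_tool]  # also allow an open-ended, tool-free answer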
@@ -121,6 +119,6 @@ impl ToolGrammar {
             },
         };

-        Ok((tools, Some(tool_schema)))
+        Ok((tools_to_use, Some(tool_schema)))
     }
 }
@@ -946,21 +946,19 @@ impl ChatRequest {
             Some(temperature) if temperature == 0.0 => (false, None),
             other => (true, other),
         };
+        // unwrap or default (use "auto" if tools are present, and "none" if not)
+        let choice = tool_choice.unwrap_or_else(|| {
+            if tools.is_some() {
+                ChatCompletionToolChoiceOption::Auto
+            } else {
+                ChatCompletionToolChoiceOption::NoTool
+            }
+        });
         let (inputs, grammar, using_tools) = prepare_chat_input(
             infer,
             response_format,
-            tools.clone(),
-            // unwrap or default (use "auto" if tools are present, and "none" if not)
-            tool_choice.map_or_else(
-                || {
-                    if tools.is_some() {
-                        ChatCompletionToolChoiceOption::Auto
-                    } else {
-                        ChatCompletionToolChoiceOption::NoTool
-                    }
-                },
-                |t| t,
-            ),
+            tools,
+            choice,
             &tool_prompt,
             guideline,
             messages,
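Here `tool_choice.map_or_else(|| default, |t| t)` collapses to the equivalent `tool_choice.unwrap_or_else(|| default)`, and hoisting the default ahead of `prepare_chat_input` removes the need to clone `tools`. The defaulting rule itself, restated as a small Python sketch (illustrative only):

# Not the actual Rust code: the unwrap-or-default rule in isolation.
def resolve_tool_choice(tool_choice, tools):
    if tool_choice is not None:
        return tool_choice          # caller set it explicitly
    return "auto" if tools is not None else "none"  # tools present => "auto"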
@@ -1023,6 +1021,7 @@ pub struct FunctionName {

 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)]
 #[serde(from = "ToolTypeDeserializer")]
+/// <https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>
 pub enum ChatCompletionToolChoiceOption {
     /// Means the model can pick between generating a message or calling one or more tools.
     #[schema(rename = "auto")]
@@ -1034,7 +1033,7 @@ pub enum ChatCompletionToolChoiceOption {
     /// Means the model must call one or more tools.
     #[schema(rename = "required")]
     Required,
-    /// Forces the model to call a specific tool.
+    /// Forces the model to call a specific tool. This structure aligns with the `OpenAI` API schema to force a specific tool.
     #[schema(rename = "function")]
     #[serde(alias = "function")]
     Function(FunctionName),
@@ -1688,32 +1687,36 @@ mod tests {
             tool_choice: ChatCompletionToolChoiceOption,
         }

-        let none = r#"{"tool_choice":"none"}"#;
-        let de_none: TestRequest = serde_json::from_str(none).unwrap();
+        let de_none: TestRequest = serde_json::from_str(r#"{"tool_choice":"none"}"#).unwrap();
         assert_eq!(de_none.tool_choice, ChatCompletionToolChoiceOption::NoTool);

-        let auto = r#"{"tool_choice":"auto"}"#;
-        let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
+        let de_auto: TestRequest = serde_json::from_str(r#"{"tool_choice":"auto"}"#).unwrap();
         assert_eq!(de_auto.tool_choice, ChatCompletionToolChoiceOption::Auto);

-        let auto = r#"{"tool_choice":"required"}"#;
-        let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
+        let de_required: TestRequest =
+            serde_json::from_str(r#"{"tool_choice":"required"}"#).unwrap();
         assert_eq!(
-            de_auto.tool_choice,
+            de_required.tool_choice,
             ChatCompletionToolChoiceOption::Required
         );

-        let ref_choice = ChatCompletionToolChoiceOption::Function(FunctionName {
-            name: "myfn".to_string(),
-        });
-
-        let named = r#"{"tool_choice":"myfn"}"#;
-        let de_named: TestRequest = serde_json::from_str(named).unwrap();
-        assert_eq!(de_named.tool_choice, ref_choice);
-
-        let openai_named = r#"{"tool_choice":{"type":"function","function":{"name":"myfn"}}}"#;
-        let de_openai_named: TestRequest = serde_json::from_str(openai_named).unwrap();
-
-        assert_eq!(de_openai_named.tool_choice, ref_choice);
+        let de_named: TestRequest = serde_json::from_str(r#"{"tool_choice":"myfn"}"#).unwrap();
+        assert_eq!(
+            de_named.tool_choice,
+            ChatCompletionToolChoiceOption::Function(FunctionName {
+                name: "myfn".to_string(),
+            })
+        );
+
+        let de_openai_named: TestRequest = serde_json::from_str(
+            r#"{"tool_choice":{"type":"function","function":{"name":"myfn"}}}"#,
+        )
+        .unwrap();
+        assert_eq!(
+            de_openai_named.tool_choice,
+            ChatCompletionToolChoiceOption::Function(FunctionName {
+                name: "myfn".to_string(),
+            })
+        );
     }
 }
@@ -2668,6 +2668,6 @@ mod tests {
         assert!(result.is_ok());
         let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
         assert_eq!(using_tools, true);
-        assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ened response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
+        assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ended response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
     }
 }