Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
fix: simplify naming, tool choice default and improve test
parent dd759e7914
commit b5bf5b32ad
@@ -921,43 +921,6 @@
           }
         }
       },
-      "ChatCompletionToolChoiceOption": {
-        "oneOf": [
-          {
-            "type": "string",
-            "description": "Means the model can pick between generating a message or calling one or more tools.",
-            "enum": [
-              "auto"
-            ]
-          },
-          {
-            "type": "string",
-            "description": "Means the model will not call any tool and instead generates a message.",
-            "enum": [
-              "none"
-            ]
-          },
-          {
-            "type": "string",
-            "description": "Means the model must call one or more tools.",
-            "enum": [
-              "required"
-            ]
-          },
-          {
-            "type": "object",
-            "required": [
-              "function"
-            ],
-            "properties": {
-              "function": {
-                "$ref": "#/components/schemas/FunctionName"
-              }
-            }
-          }
-        ],
-        "description": "<https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>"
-      },
       "ChatCompletionTopLogprob": {
         "type": "object",
         "required": [
@@ -1052,7 +1015,7 @@
             "$ref": "#/components/schemas/GrammarType"
           }
         ],
-        "default": "null",
+        "default": "auto",
         "nullable": true
       },
       "seed": {
@@ -1092,7 +1055,7 @@
       "tool_choice": {
         "allOf": [
           {
-            "$ref": "#/components/schemas/ChatCompletionToolChoiceOption"
+            "$ref": "#/components/schemas/ToolChoice"
           }
         ],
         "default": "null",
@@ -2272,6 +2235,43 @@
           }
         }
       },
+      "ToolChoice": {
+        "oneOf": [
+          {
+            "type": "string",
+            "description": "Means the model can pick between generating a message or calling one or more tools.",
+            "enum": [
+              "auto"
+            ]
+          },
+          {
+            "type": "string",
+            "description": "Means the model will not call any tool and instead generates a message.",
+            "enum": [
+              "none"
+            ]
+          },
+          {
+            "type": "string",
+            "description": "Means the model must call one or more tools.",
+            "enum": [
+              "required"
+            ]
+          },
+          {
+            "type": "object",
+            "required": [
+              "function"
+            ],
+            "properties": {
+              "function": {
+                "$ref": "#/components/schemas/FunctionName"
+              }
+            }
+          }
+        ],
+        "description": "<https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>"
+      },
       "Url": {
         "type": "object",
         "required": [
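For orientation only (not part of the commit), a small sketch of the four `tool_choice` values the new `ToolChoice` schema above admits, written in Rust with `serde_json` assumed as a dependency; the function name in the object form is purely illustrative:

```rust
// Sketch of the four shapes allowed by the `oneOf` in the `ToolChoice` schema above.
// Only the `tool_choice` values matter; everything else here is illustrative.
use serde_json::json;

fn main() {
    let choices = [
        json!("auto"),     // model may answer or call one or more tools
        json!("none"),     // model must answer without calling a tool
        json!("required"), // model must call one or more tools
        // Object form: force one specific function ("get_current_weather" is just an example name).
        json!({ "function": { "name": "get_current_weather" } }),
    ];
    for choice in &choices {
        println!("{choice}");
    }
}
```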
@@ -18,7 +18,7 @@
       "logprobs": null
     }
   ],
-  "created": 1729000499,
+  "created": 1729084854,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
@@ -19,7 +19,7 @@
       "logprobs": null
     }
   ],
-  "created": 1728998230,
+  "created": 1729084850,
   "id": "",
   "model": "meta-llama/Llama-3.1-8B-Instruct",
   "object": "chat.completion.chunk",
@@ -395,7 +395,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
         "tools": tools,
         "tool_choice": {
             "type": "function",
-            "function": {"name": "get_current_weather"},
+            "function": {"name": "get_n_day_weather_forecast"},
         },
         "seed": 24,
         "max_tokens": 100,
@@ -421,10 +421,9 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
             ]["arguments"]
         last_response = response

-    assert count == 30
-    print(tool_calls_generated)
+    assert count == 39
     assert (
         tool_calls_generated
-        == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Tokyo, JP"}}<|eot_id|>'
+        == '{"function": {"_name": "get_n_day_weather_forecast", "format": "celsius", "location": "San Francisco, CA", "num_days":3}}<|eot_id|>'
     )
     assert last_response == response_snapshot
@@ -1,7 +1,6 @@
 use crate::infer::InferError;
 use crate::{
-    ChatCompletionToolChoiceOption, FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool,
-    Properties, Tool,
+    FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
 };
 use serde_json::{json, Map, Value};
 use std::collections::HashMap;
@@ -20,19 +19,19 @@ impl ToolGrammar {

     pub fn apply(
         tools: Vec<Tool>,
-        tool_choice: ChatCompletionToolChoiceOption,
+        tool_choice: ToolChoice,
     ) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
-        // if no tools are provided, we return None
+        // if no tools are provided, we return None and an empty vec
         if tools.is_empty() {
-            return Ok((tools, None));
+            return Ok((Vec::with_capacity(0), None));
         }

         let tools_to_use = match tool_choice {
-            ChatCompletionToolChoiceOption::Function(function) => {
+            ToolChoice::Function(function) => {
                 vec![Self::find_tool_by_name(&tools, &function.name)?]
             }
-            ChatCompletionToolChoiceOption::Required => tools,
-            ChatCompletionToolChoiceOption::Auto => {
+            ToolChoice::Required => tools,
+            ToolChoice::Auto => {
                 // only add the no_tool function if the user has selected the auto option
                 tools
                     .iter()
@@ -58,7 +57,7 @@ impl ToolGrammar {
                     }))
                     .collect::<Vec<_>>()
             }
-            ChatCompletionToolChoiceOption::NoTool => Vec::with_capacity(0),
+            ToolChoice::NoTool => Vec::with_capacity(0),
         };

         let functions: HashMap<String, serde_json::Value> = tools_to_use
@@ -893,13 +893,13 @@ pub(crate) struct ChatRequest {
     /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
     #[serde(default)]
     #[schema(nullable = true, default = "null", example = "null")]
-    pub tool_choice: Option<ChatCompletionToolChoiceOption>,
+    pub tool_choice: Option<ToolChoice>,

     /// Response format constraints for the generation.
     ///
     /// NOTE: A request can use `response_format` OR `tools` but not both.
     #[serde(default)]
-    #[schema(nullable = true, default = "null", example = "null")]
+    #[schema(nullable = true, default = "auto", example = "auto")]
     pub response_format: Option<GrammarType>,

     /// A guideline to be used in the chat_template
@@ -946,14 +946,8 @@ impl ChatRequest {
             Some(temperature) if temperature == 0.0 => (false, None),
             other => (true, other),
         };
-        // unwrap or default (use "auto" if tools are present, and "none" if not)
-        let tool_choice = tool_choice.unwrap_or_else(|| {
-            if tools.is_some() {
-                ChatCompletionToolChoiceOption::Auto
-            } else {
-                ChatCompletionToolChoiceOption::NoTool
-            }
-        });
+        // if no tool_choice is set, set default (Auto)
+        let tool_choice = tool_choice.unwrap_or_default();

         if response_format.is_some() && tools.is_some() {
             return Err(InferError::ToolError(
@@ -1045,21 +1039,18 @@ pub struct FunctionName {

 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)]
 #[serde(from = "ToolTypeDeserializer")]
+#[serde(rename_all = "snake_case")]
 /// <https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>
-pub enum ChatCompletionToolChoiceOption {
+pub enum ToolChoice {
     /// Means the model can pick between generating a message or calling one or more tools.
-    #[schema(rename = "auto")]
+    #[default]
     Auto,
     /// Means the model will not call any tool and instead generates a message.
     #[schema(rename = "none")]
-    #[default]
     NoTool,
     /// Means the model must call one or more tools.
-    #[schema(rename = "required")]
     Required,
     /// Forces the model to call a specific tool. This structure aligns with the `OpenAI` API schema to force a specific tool.
-    #[schema(rename = "function")]
-    #[serde(alias = "function")]
     Function(FunctionName),
 }

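A minimal standalone sketch (simplified enum, not the crate's real definition) of what moving `#[default]` onto `Auto` means for the `unwrap_or_default()` call introduced in `ChatRequest` above:

```rust
// Sketch: with `#[default]` on Auto, an omitted tool_choice now resolves to Auto.
#[derive(Debug, Default, PartialEq)]
enum ToolChoice {
    #[default]
    Auto,
    NoTool,
    Required,
}

fn main() {
    // Mirrors `let tool_choice = tool_choice.unwrap_or_default();` from the ChatRequest hunk above.
    let tool_choice: Option<ToolChoice> = None;
    assert_eq!(tool_choice.unwrap_or_default(), ToolChoice::Auto);
}
```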
@@ -1085,18 +1076,18 @@ enum ToolTypeDeserializer {
     TypedChoice(TypedChoice),
 }

-impl From<ToolTypeDeserializer> for ChatCompletionToolChoiceOption {
+impl From<ToolTypeDeserializer> for ToolChoice {
     fn from(value: ToolTypeDeserializer) -> Self {
         match value {
-            ToolTypeDeserializer::Null => ChatCompletionToolChoiceOption::NoTool,
+            ToolTypeDeserializer::Null => ToolChoice::NoTool,
             ToolTypeDeserializer::String(s) => match s.as_str() {
-                "none" => ChatCompletionToolChoiceOption::NoTool,
-                "auto" => ChatCompletionToolChoiceOption::Auto,
-                "required" => ChatCompletionToolChoiceOption::Required,
-                _ => ChatCompletionToolChoiceOption::Function(FunctionName { name: s }),
+                "none" => ToolChoice::NoTool,
+                "auto" => ToolChoice::Auto,
+                "required" => ToolChoice::Required,
+                _ => ToolChoice::Function(FunctionName { name: s }),
             },
             ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => {
-                ChatCompletionToolChoiceOption::Function(function)
+                ToolChoice::Function(function)
             }
         }
     }
 }
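The string/object mapping in the `From` impl above can be exercised end to end with a self-contained sketch like the following (a simplified re-implementation for illustration; `serde` and `serde_json` are assumed dependencies, and the untagged helper here only approximates the crate's `ToolTypeDeserializer`):

```rust
use serde::Deserialize;

// Simplified stand-in mirroring `impl From<ToolTypeDeserializer> for ToolChoice` above.
#[derive(Debug, PartialEq, Deserialize)]
struct FunctionName {
    name: String,
}

#[derive(Debug, PartialEq, Deserialize)]
#[serde(from = "ToolTypeDeserializer")]
enum ToolChoice {
    Auto,
    NoTool,
    Required,
    Function(FunctionName),
}

// Untagged helper: accepts null, a bare string, or the `{"function": {...}}` object form.
#[derive(Deserialize)]
#[serde(untagged)]
enum ToolTypeDeserializer {
    Null(Option<()>),
    String(String),
    TypedChoice { function: FunctionName },
}

impl From<ToolTypeDeserializer> for ToolChoice {
    fn from(value: ToolTypeDeserializer) -> Self {
        match value {
            ToolTypeDeserializer::Null(_) => ToolChoice::NoTool,
            ToolTypeDeserializer::String(s) => match s.as_str() {
                "none" => ToolChoice::NoTool,
                "auto" => ToolChoice::Auto,
                "required" => ToolChoice::Required,
                _ => ToolChoice::Function(FunctionName { name: s }),
            },
            ToolTypeDeserializer::TypedChoice { function } => ToolChoice::Function(function),
        }
    }
}

fn main() {
    let auto: ToolChoice = serde_json::from_str(r#""auto""#).unwrap();
    assert_eq!(auto, ToolChoice::Auto);

    let named: ToolChoice = serde_json::from_str(r#"{"function": {"name": "myfn"}}"#).unwrap();
    assert_eq!(
        named,
        ToolChoice::Function(FunctionName { name: "myfn".to_string() })
    );
}
```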
@@ -1709,26 +1700,23 @@ mod tests {
     fn tool_choice_formats() {
         #[derive(Deserialize)]
         struct TestRequest {
-            tool_choice: ChatCompletionToolChoiceOption,
+            tool_choice: ToolChoice,
         }

         let de_none: TestRequest = serde_json::from_str(r#"{"tool_choice":"none"}"#).unwrap();
-        assert_eq!(de_none.tool_choice, ChatCompletionToolChoiceOption::NoTool);
+        assert_eq!(de_none.tool_choice, ToolChoice::NoTool);

         let de_auto: TestRequest = serde_json::from_str(r#"{"tool_choice":"auto"}"#).unwrap();
-        assert_eq!(de_auto.tool_choice, ChatCompletionToolChoiceOption::Auto);
+        assert_eq!(de_auto.tool_choice, ToolChoice::Auto);

         let de_required: TestRequest =
             serde_json::from_str(r#"{"tool_choice":"required"}"#).unwrap();
-        assert_eq!(
-            de_required.tool_choice,
-            ChatCompletionToolChoiceOption::Required
-        );
+        assert_eq!(de_required.tool_choice, ToolChoice::Required);

         let de_named: TestRequest = serde_json::from_str(r#"{"tool_choice":"myfn"}"#).unwrap();
         assert_eq!(
             de_named.tool_choice,
-            ChatCompletionToolChoiceOption::Function(FunctionName {
+            ToolChoice::Function(FunctionName {
                 name: "myfn".to_string(),
             })
         );
@@ -1739,7 +1727,7 @@ mod tests {
             .unwrap();
         assert_eq!(
             de_openai_named.tool_choice,
-            ChatCompletionToolChoiceOption::Function(FunctionName {
+            ToolChoice::Function(FunctionName {
                 name: "myfn".to_string(),
             })
         );
@@ -27,7 +27,7 @@ use crate::{
     ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
     CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
 };
-use crate::{ChatCompletionToolChoiceOption, FunctionDefinition, HubPreprocessorConfig, ToolCall};
+use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
 use crate::{ModelInfo, ModelsInfo};
 use async_stream::__private::AsyncStream;
 use axum::extract::Extension;
@@ -1554,7 +1554,7 @@ Tool,
         ToolCall,
         Function,
         FunctionDefinition,
-        ChatCompletionToolChoiceOption,
+        ToolChoice,
         ModelInfo,
         )
     ),