mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
feat: improve, simplify and rename tool choice struct add required support and refactor
This commit is contained in:
parent
209f841767
commit
f53c8059e9
@ -921,6 +921,42 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"ChatCompletionToolChoiceOption": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Means the model can pick between generating a message or calling one or more tools.",
|
||||||
|
"enum": [
|
||||||
|
"auto"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Means the model will not call any tool and instead generates a message.",
|
||||||
|
"enum": [
|
||||||
|
"none"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Means the model must call one or more tools.",
|
||||||
|
"enum": [
|
||||||
|
"required"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"function"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"function": {
|
||||||
|
"$ref": "#/components/schemas/FunctionName"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
"ChatCompletionTopLogprob": {
|
"ChatCompletionTopLogprob": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
@ -1055,9 +1091,10 @@
|
|||||||
"tool_choice": {
|
"tool_choice": {
|
||||||
"allOf": [
|
"allOf": [
|
||||||
{
|
{
|
||||||
"$ref": "#/components/schemas/ToolChoice"
|
"$ref": "#/components/schemas/ChatCompletionToolChoiceOption"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
"default": "null",
|
||||||
"nullable": true
|
"nullable": true
|
||||||
},
|
},
|
||||||
"tool_prompt": {
|
"tool_prompt": {
|
||||||
@ -2234,45 +2271,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ToolChoice": {
|
|
||||||
"allOf": [
|
|
||||||
{
|
|
||||||
"$ref": "#/components/schemas/ToolType"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"nullable": true
|
|
||||||
},
|
|
||||||
"ToolType": {
|
|
||||||
"oneOf": [
|
|
||||||
{
|
|
||||||
"type": "string",
|
|
||||||
"description": "Means the model can pick between generating a message or calling one or more tools.",
|
|
||||||
"enum": [
|
|
||||||
"auto"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "string",
|
|
||||||
"description": "Means the model will not call any tool and instead generates a message.",
|
|
||||||
"enum": [
|
|
||||||
"none"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "object",
|
|
||||||
"required": [
|
|
||||||
"function"
|
|
||||||
],
|
|
||||||
"properties": {
|
|
||||||
"function": {
|
|
||||||
"$ref": "#/components/schemas/FunctionName"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"description": "Controls which (if any) tool is called by the model.",
|
|
||||||
"example": "auto"
|
|
||||||
},
|
|
||||||
"Url": {
|
"Url": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use crate::infer::InferError;
|
use crate::infer::InferError;
|
||||||
use crate::{
|
use crate::{
|
||||||
FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
|
ChatCompletionToolChoiceOption, FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool,
|
||||||
ToolType,
|
Properties, Tool,
|
||||||
};
|
};
|
||||||
use serde_json::{json, Map, Value};
|
use serde_json::{json, Map, Value};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@ -20,44 +20,47 @@ impl ToolGrammar {
|
|||||||
|
|
||||||
pub fn apply(
|
pub fn apply(
|
||||||
tools: Vec<Tool>,
|
tools: Vec<Tool>,
|
||||||
tool_choice: ToolChoice,
|
tool_choice: ChatCompletionToolChoiceOption,
|
||||||
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
|
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
|
||||||
// if no tools are provided, we return None
|
// if no tools are provided, we return None
|
||||||
if tools.is_empty() {
|
if tools.is_empty() {
|
||||||
return Ok((tools, None));
|
return Ok((tools, None));
|
||||||
}
|
}
|
||||||
|
|
||||||
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
|
|
||||||
|
|
||||||
let mut tools = tools.clone();
|
let mut tools = tools.clone();
|
||||||
|
|
||||||
// add the no_tool function to the tools
|
// add the no_tool function to the tools unless tool_choice is `required` (the model must then call one of the provided tools)
|
||||||
let no_tool = Tool {
|
if tool_choice != ChatCompletionToolChoiceOption::Required {
|
||||||
r#type: "function".to_string(),
|
let no_tool = Tool {
|
||||||
function: FunctionDefinition {
|
r#type: "function".to_string(),
|
||||||
name: "no_tool".to_string(),
|
function: FunctionDefinition {
|
||||||
description: Some("Open ened response with no specific tool selected".to_string()),
|
name: "no_tool".to_string(),
|
||||||
arguments: json!({
|
description: Some(
|
||||||
"type": "object",
|
"Open ended response with no specific tool selected".to_string(),
|
||||||
"properties": {
|
),
|
||||||
"content": {
|
arguments: json!({
|
||||||
"type": "string",
|
"type": "object",
|
||||||
"description": "The response content",
|
"properties": {
|
||||||
}
|
"content": {
|
||||||
},
|
"type": "string",
|
||||||
"required": ["content"]
|
"description": "The response content",
|
||||||
}),
|
}
|
||||||
},
|
},
|
||||||
};
|
"required": ["content"]
|
||||||
tools.push(no_tool);
|
}),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
tools.push(no_tool);
|
||||||
|
}
|
||||||
|
|
||||||
// if tools are provided and no tool_choice we default to the OneOf
|
// select the tools to use based on the (already resolved) tool_choice
|
||||||
let tools_to_use = match tool_choice {
|
let tools_to_use = match tool_choice {
|
||||||
ToolType::Function(function) => {
|
ChatCompletionToolChoiceOption::Function(function) => {
|
||||||
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
||||||
}
|
}
|
||||||
ToolType::OneOf => tools.clone(),
|
ChatCompletionToolChoiceOption::Required => tools.clone(),
|
||||||
ToolType::NoTool => return Ok((tools, None)),
|
ChatCompletionToolChoiceOption::Auto => tools.clone(),
|
||||||
|
ChatCompletionToolChoiceOption::NoTool => return Ok((tools, None)),
|
||||||
};
|
};
|
||||||
|
|
||||||
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
||||||
|
@ -892,8 +892,8 @@ pub(crate) struct ChatRequest {
|
|||||||
|
|
||||||
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(nullable = true, example = "null")]
|
#[schema(nullable = true, default = "null", example = "null")]
|
||||||
pub tool_choice: ToolChoice,
|
pub tool_choice: Option<ChatCompletionToolChoiceOption>,
|
||||||
|
|
||||||
/// Response format constraints for the generation.
|
/// Response format constraints for the generation.
|
||||||
///
|
///
|
||||||
@ -949,8 +949,18 @@ impl ChatRequest {
|
|||||||
let (inputs, grammar, using_tools) = prepare_chat_input(
|
let (inputs, grammar, using_tools) = prepare_chat_input(
|
||||||
infer,
|
infer,
|
||||||
response_format,
|
response_format,
|
||||||
tools,
|
tools.clone(),
|
||||||
tool_choice,
|
// unwrap or default (use "auto" if tools are present, and "none" if not)
|
||||||
|
tool_choice.map_or_else(
|
||||||
|
|| {
|
||||||
|
if tools.is_some() {
|
||||||
|
ChatCompletionToolChoiceOption::Auto
|
||||||
|
} else {
|
||||||
|
ChatCompletionToolChoiceOption::NoTool
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|t| t,
|
||||||
|
),
|
||||||
&tool_prompt,
|
&tool_prompt,
|
||||||
guideline,
|
guideline,
|
||||||
messages,
|
messages,
|
||||||
@ -999,22 +1009,6 @@ pub fn default_tool_prompt() -> String {
|
|||||||
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
|
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
|
|
||||||
#[schema(example = "auto")]
|
|
||||||
/// Controls which (if any) tool is called by the model.
|
|
||||||
pub enum ToolType {
|
|
||||||
/// Means the model can pick between generating a message or calling one or more tools.
|
|
||||||
#[schema(rename = "auto")]
|
|
||||||
OneOf,
|
|
||||||
/// Means the model will not call any tool and instead generates a message.
|
|
||||||
#[schema(rename = "none")]
|
|
||||||
NoTool,
|
|
||||||
/// Forces the model to call a specific tool.
|
|
||||||
#[schema(rename = "function")]
|
|
||||||
#[serde(alias = "function")]
|
|
||||||
Function(FunctionName),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
pub enum TypedChoice {
|
pub enum TypedChoice {
|
||||||
@ -1027,29 +1021,59 @@ pub struct FunctionName {
|
|||||||
pub name: String,
|
pub name: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)]
|
||||||
#[serde(from = "ToolTypeDeserializer")]
|
#[serde(from = "ToolTypeDeserializer")]
|
||||||
pub struct ToolChoice(pub Option<ToolType>);
|
pub enum ChatCompletionToolChoiceOption {
|
||||||
|
/// Means the model can pick between generating a message or calling one or more tools.
|
||||||
#[derive(Deserialize)]
|
#[schema(rename = "auto")]
|
||||||
#[serde(untagged)]
|
Auto,
|
||||||
enum ToolTypeDeserializer {
|
/// Means the model will not call any tool and instead generates a message.
|
||||||
Null,
|
#[schema(rename = "none")]
|
||||||
String(String),
|
#[default]
|
||||||
ToolType(TypedChoice),
|
NoTool,
|
||||||
|
/// Means the model must call one or more tools.
|
||||||
|
#[schema(rename = "required")]
|
||||||
|
Required,
|
||||||
|
/// Forces the model to call a specific tool.
|
||||||
|
#[schema(rename = "function")]
|
||||||
|
#[serde(alias = "function")]
|
||||||
|
Function(FunctionName),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<ToolTypeDeserializer> for ToolChoice {
|
#[derive(Deserialize, ToSchema)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
/// Controls which (if any) tool is called by the model.
|
||||||
|
/// - `none` means the model will not call any tool and instead generates a message.
|
||||||
|
/// - `auto` means the model can pick between generating a message or calling one or more tools.
|
||||||
|
/// - `required` means the model must call one or more tools.
|
||||||
|
/// - Specifying a particular tool via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool.
|
||||||
|
///
|
||||||
|
/// `none` is the default when no tools are present. `auto` is the default if tools are present.
|
||||||
|
enum ToolTypeDeserializer {
|
||||||
|
/// `none` means the model will not call any tool and instead generates a message.
|
||||||
|
Null,
|
||||||
|
|
||||||
|
/// A plain string: `none`, `auto` or `required`; any other value is treated as the name of the function to call.
|
||||||
|
#[schema(example = "auto")]
|
||||||
|
String(String),
|
||||||
|
|
||||||
|
/// Specifying a particular tool forces the model to call that tool, with structured function details.
|
||||||
|
#[schema(example = r#"{"type": "function", "function": {"name": "my_function"}}"#)]
|
||||||
|
TypedChoice(TypedChoice),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ToolTypeDeserializer> for ChatCompletionToolChoiceOption {
|
||||||
fn from(value: ToolTypeDeserializer) -> Self {
|
fn from(value: ToolTypeDeserializer) -> Self {
|
||||||
match value {
|
match value {
|
||||||
ToolTypeDeserializer::Null => ToolChoice(None),
|
ToolTypeDeserializer::Null => ChatCompletionToolChoiceOption::NoTool,
|
||||||
ToolTypeDeserializer::String(s) => match s.as_str() {
|
ToolTypeDeserializer::String(s) => match s.as_str() {
|
||||||
"none" => ToolChoice(Some(ToolType::NoTool)),
|
"none" => ChatCompletionToolChoiceOption::NoTool,
|
||||||
"auto" => ToolChoice(Some(ToolType::OneOf)),
|
"auto" => ChatCompletionToolChoiceOption::Auto,
|
||||||
_ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))),
|
"required" => ChatCompletionToolChoiceOption::Required,
|
||||||
|
_ => ChatCompletionToolChoiceOption::Function(FunctionName { name: s }),
|
||||||
},
|
},
|
||||||
ToolTypeDeserializer::ToolType(TypedChoice::Function { function }) => {
|
ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => {
|
||||||
ToolChoice(Some(ToolType::Function(function)))
|
ChatCompletionToolChoiceOption::Function(function)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1661,20 +1685,27 @@ mod tests {
|
|||||||
fn tool_choice_formats() {
|
fn tool_choice_formats() {
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct TestRequest {
|
struct TestRequest {
|
||||||
tool_choice: ToolChoice,
|
tool_choice: ChatCompletionToolChoiceOption,
|
||||||
}
|
}
|
||||||
|
|
||||||
let none = r#"{"tool_choice":"none"}"#;
|
let none = r#"{"tool_choice":"none"}"#;
|
||||||
let de_none: TestRequest = serde_json::from_str(none).unwrap();
|
let de_none: TestRequest = serde_json::from_str(none).unwrap();
|
||||||
assert_eq!(de_none.tool_choice, ToolChoice(Some(ToolType::NoTool)));
|
assert_eq!(de_none.tool_choice, ChatCompletionToolChoiceOption::NoTool);
|
||||||
|
|
||||||
let auto = r#"{"tool_choice":"auto"}"#;
|
let auto = r#"{"tool_choice":"auto"}"#;
|
||||||
let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
|
let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
|
||||||
assert_eq!(de_auto.tool_choice, ToolChoice(Some(ToolType::OneOf)));
|
assert_eq!(de_auto.tool_choice, ChatCompletionToolChoiceOption::Auto);
|
||||||
|
|
||||||
let ref_choice = ToolChoice(Some(ToolType::Function(FunctionName {
|
let auto = r#"{"tool_choice":"required"}"#;
|
||||||
|
let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
de_auto.tool_choice,
|
||||||
|
ChatCompletionToolChoiceOption::Required
|
||||||
|
);
|
||||||
|
|
||||||
|
let ref_choice = ChatCompletionToolChoiceOption::Function(FunctionName {
|
||||||
name: "myfn".to_string(),
|
name: "myfn".to_string(),
|
||||||
})));
|
});
|
||||||
|
|
||||||
let named = r#"{"tool_choice":"myfn"}"#;
|
let named = r#"{"tool_choice":"myfn"}"#;
|
||||||
let de_named: TestRequest = serde_json::from_str(named).unwrap();
|
let de_named: TestRequest = serde_json::from_str(named).unwrap();
|
||||||
|
@ -28,7 +28,7 @@ use crate::{
|
|||||||
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
|
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
|
||||||
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
|
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
|
||||||
};
|
};
|
||||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
|
use crate::{ChatCompletionToolChoiceOption, FunctionDefinition, HubPreprocessorConfig, ToolCall};
|
||||||
use crate::{ModelInfo, ModelsInfo};
|
use crate::{ModelInfo, ModelsInfo};
|
||||||
use async_stream::__private::AsyncStream;
|
use async_stream::__private::AsyncStream;
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
@ -1551,12 +1551,11 @@ GrammarType,
|
|||||||
Usage,
|
Usage,
|
||||||
StreamOptions,
|
StreamOptions,
|
||||||
DeltaToolCall,
|
DeltaToolCall,
|
||||||
ToolType,
|
|
||||||
Tool,
|
Tool,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
Function,
|
Function,
|
||||||
FunctionDefinition,
|
FunctionDefinition,
|
||||||
ToolChoice,
|
ChatCompletionToolChoiceOption,
|
||||||
ModelInfo,
|
ModelInfo,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
@ -2522,7 +2521,7 @@ pub(crate) fn prepare_chat_input(
|
|||||||
infer: &Infer,
|
infer: &Infer,
|
||||||
response_format: Option<GrammarType>,
|
response_format: Option<GrammarType>,
|
||||||
tools: Option<Vec<Tool>>,
|
tools: Option<Vec<Tool>>,
|
||||||
tool_choice: ToolChoice,
|
tool_choice: ChatCompletionToolChoiceOption,
|
||||||
tool_prompt: &str,
|
tool_prompt: &str,
|
||||||
guideline: Option<String>,
|
guideline: Option<String>,
|
||||||
messages: Vec<Message>,
|
messages: Vec<Message>,
|
||||||
@ -2660,7 +2659,7 @@ mod tests {
|
|||||||
&infer,
|
&infer,
|
||||||
response_format,
|
response_format,
|
||||||
tools,
|
tools,
|
||||||
ToolChoice(None),
|
ChatCompletionToolChoiceOption::Auto,
|
||||||
tool_prompt,
|
tool_prompt,
|
||||||
guideline,
|
guideline,
|
||||||
messages,
|
messages,
|
||||||
|
Loading…
Reference in New Issue
Block a user