mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
feat: improve, simplify and rename tool choice struct add required support and refactor
This commit is contained in:
parent
209f841767
commit
f53c8059e9
@ -921,6 +921,42 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"ChatCompletionToolChoiceOption": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Means the model can pick between generating a message or calling one or more tools.",
|
||||
"enum": [
|
||||
"auto"
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Means the model will not call any tool and instead generates a message.",
|
||||
"enum": [
|
||||
"none"
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Means the model must call one or more tools.",
|
||||
"enum": [
|
||||
"required"
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"function"
|
||||
],
|
||||
"properties": {
|
||||
"function": {
|
||||
"$ref": "#/components/schemas/FunctionName"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"ChatCompletionTopLogprob": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
@ -1055,9 +1091,10 @@
|
||||
"tool_choice": {
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/ToolChoice"
|
||||
"$ref": "#/components/schemas/ChatCompletionToolChoiceOption"
|
||||
}
|
||||
],
|
||||
"default": "null",
|
||||
"nullable": true
|
||||
},
|
||||
"tool_prompt": {
|
||||
@ -2234,45 +2271,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"ToolChoice": {
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/ToolType"
|
||||
}
|
||||
],
|
||||
"nullable": true
|
||||
},
|
||||
"ToolType": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Means the model can pick between generating a message or calling one or more tools.",
|
||||
"enum": [
|
||||
"auto"
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Means the model will not call any tool and instead generates a message.",
|
||||
"enum": [
|
||||
"none"
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"function"
|
||||
],
|
||||
"properties": {
|
||||
"function": {
|
||||
"$ref": "#/components/schemas/FunctionName"
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"description": "Controls which (if any) tool is called by the model.",
|
||||
"example": "auto"
|
||||
},
|
||||
"Url": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
|
@ -1,7 +1,7 @@
|
||||
use crate::infer::InferError;
|
||||
use crate::{
|
||||
FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
|
||||
ToolType,
|
||||
ChatCompletionToolChoiceOption, FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool,
|
||||
Properties, Tool,
|
||||
};
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::collections::HashMap;
|
||||
@ -20,23 +20,24 @@ impl ToolGrammar {
|
||||
|
||||
pub fn apply(
|
||||
tools: Vec<Tool>,
|
||||
tool_choice: ToolChoice,
|
||||
tool_choice: ChatCompletionToolChoiceOption,
|
||||
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
|
||||
// if no tools are provided, we return None
|
||||
if tools.is_empty() {
|
||||
return Ok((tools, None));
|
||||
}
|
||||
|
||||
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
|
||||
|
||||
let mut tools = tools.clone();
|
||||
|
||||
// add the no_tool function to the tools
|
||||
// add the no_tool function to the tools unless tool_choice is `Required` (the model must then call a tool)
|
||||
if tool_choice != ChatCompletionToolChoiceOption::Required {
|
||||
let no_tool = Tool {
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionDefinition {
|
||||
name: "no_tool".to_string(),
|
||||
description: Some("Open ened response with no specific tool selected".to_string()),
|
||||
description: Some(
|
||||
"Open ended response with no specific tool selected".to_string(),
|
||||
),
|
||||
arguments: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@ -50,14 +51,16 @@ impl ToolGrammar {
|
||||
},
|
||||
};
|
||||
tools.push(no_tool);
|
||||
}
|
||||
|
||||
// select which tools to use based on tool_choice (a missing tool_choice is resolved to a default by the caller)
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolType::Function(function) => {
|
||||
ChatCompletionToolChoiceOption::Function(function) => {
|
||||
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
||||
}
|
||||
ToolType::OneOf => tools.clone(),
|
||||
ToolType::NoTool => return Ok((tools, None)),
|
||||
ChatCompletionToolChoiceOption::Required => tools.clone(),
|
||||
ChatCompletionToolChoiceOption::Auto => tools.clone(),
|
||||
ChatCompletionToolChoiceOption::NoTool => return Ok((tools, None)),
|
||||
};
|
||||
|
||||
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
||||
|
@ -892,8 +892,8 @@ pub(crate) struct ChatRequest {
|
||||
|
||||
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub tool_choice: ToolChoice,
|
||||
#[schema(nullable = true, default = "null", example = "null")]
|
||||
pub tool_choice: Option<ChatCompletionToolChoiceOption>,
|
||||
|
||||
/// Response format constraints for the generation.
|
||||
///
|
||||
@ -949,8 +949,18 @@ impl ChatRequest {
|
||||
let (inputs, grammar, using_tools) = prepare_chat_input(
|
||||
infer,
|
||||
response_format,
|
||||
tools,
|
||||
tool_choice,
|
||||
tools.clone(),
|
||||
// unwrap or default (use "auto" if tools are present, and "none" if not)
|
||||
tool_choice.map_or_else(
|
||||
|| {
|
||||
if tools.is_some() {
|
||||
ChatCompletionToolChoiceOption::Auto
|
||||
} else {
|
||||
ChatCompletionToolChoiceOption::NoTool
|
||||
}
|
||||
},
|
||||
|t| t,
|
||||
),
|
||||
&tool_prompt,
|
||||
guideline,
|
||||
messages,
|
||||
@ -999,22 +1009,6 @@ pub fn default_tool_prompt() -> String {
|
||||
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
|
||||
#[schema(example = "auto")]
|
||||
/// Controls which (if any) tool is called by the model.
|
||||
pub enum ToolType {
|
||||
/// Means the model can pick between generating a message or calling one or more tools.
|
||||
#[schema(rename = "auto")]
|
||||
OneOf,
|
||||
/// Means the model will not call any tool and instead generates a message.
|
||||
#[schema(rename = "none")]
|
||||
NoTool,
|
||||
/// Forces the model to call a specific tool.
|
||||
#[schema(rename = "function")]
|
||||
#[serde(alias = "function")]
|
||||
Function(FunctionName),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum TypedChoice {
|
||||
@ -1027,29 +1021,59 @@ pub struct FunctionName {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)]
|
||||
#[serde(from = "ToolTypeDeserializer")]
|
||||
pub struct ToolChoice(pub Option<ToolType>);
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ToolTypeDeserializer {
|
||||
Null,
|
||||
String(String),
|
||||
ToolType(TypedChoice),
|
||||
pub enum ChatCompletionToolChoiceOption {
|
||||
/// Means the model can pick between generating a message or calling one or more tools.
|
||||
#[schema(rename = "auto")]
|
||||
Auto,
|
||||
/// Means the model will not call any tool and instead generates a message.
|
||||
#[schema(rename = "none")]
|
||||
#[default]
|
||||
NoTool,
|
||||
/// Means the model must call one or more tools.
|
||||
#[schema(rename = "required")]
|
||||
Required,
|
||||
/// Forces the model to call a specific tool.
|
||||
#[schema(rename = "function")]
|
||||
#[serde(alias = "function")]
|
||||
Function(FunctionName),
|
||||
}
|
||||
|
||||
impl From<ToolTypeDeserializer> for ToolChoice {
|
||||
#[derive(Deserialize, ToSchema)]
|
||||
#[serde(untagged)]
|
||||
/// Controls which (if any) tool is called by the model.
|
||||
/// - `none` means the model will not call any tool and instead generates a message.
|
||||
/// - `auto` means the model can pick between generating a message or calling one or more tools.
|
||||
/// - `required` means the model must call one or more tools.
|
||||
/// - Specifying a particular tool via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool.
|
||||
///
|
||||
/// `none` is the default when no tools are present. `auto` is the default if tools are present.
|
||||
enum ToolTypeDeserializer {
|
||||
/// `none` means the model will not call any tool and instead generates a message.
|
||||
Null,
|
||||
|
||||
/// `auto` means the model can pick between generating a message or calling one or more tools.
|
||||
#[schema(example = "auto")]
|
||||
String(String),
|
||||
|
||||
/// Specifying a particular tool forces the model to call that tool, with structured function details.
|
||||
#[schema(example = r#"{"type": "function", "function": {"name": "my_function"}}"#)]
|
||||
TypedChoice(TypedChoice),
|
||||
}
|
||||
|
||||
impl From<ToolTypeDeserializer> for ChatCompletionToolChoiceOption {
|
||||
fn from(value: ToolTypeDeserializer) -> Self {
|
||||
match value {
|
||||
ToolTypeDeserializer::Null => ToolChoice(None),
|
||||
ToolTypeDeserializer::Null => ChatCompletionToolChoiceOption::NoTool,
|
||||
ToolTypeDeserializer::String(s) => match s.as_str() {
|
||||
"none" => ToolChoice(Some(ToolType::NoTool)),
|
||||
"auto" => ToolChoice(Some(ToolType::OneOf)),
|
||||
_ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))),
|
||||
"none" => ChatCompletionToolChoiceOption::NoTool,
|
||||
"auto" => ChatCompletionToolChoiceOption::Auto,
|
||||
"required" => ChatCompletionToolChoiceOption::Required,
|
||||
_ => ChatCompletionToolChoiceOption::Function(FunctionName { name: s }),
|
||||
},
|
||||
ToolTypeDeserializer::ToolType(TypedChoice::Function { function }) => {
|
||||
ToolChoice(Some(ToolType::Function(function)))
|
||||
ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => {
|
||||
ChatCompletionToolChoiceOption::Function(function)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1661,20 +1685,27 @@ mod tests {
|
||||
fn tool_choice_formats() {
|
||||
#[derive(Deserialize)]
|
||||
struct TestRequest {
|
||||
tool_choice: ToolChoice,
|
||||
tool_choice: ChatCompletionToolChoiceOption,
|
||||
}
|
||||
|
||||
let none = r#"{"tool_choice":"none"}"#;
|
||||
let de_none: TestRequest = serde_json::from_str(none).unwrap();
|
||||
assert_eq!(de_none.tool_choice, ToolChoice(Some(ToolType::NoTool)));
|
||||
assert_eq!(de_none.tool_choice, ChatCompletionToolChoiceOption::NoTool);
|
||||
|
||||
let auto = r#"{"tool_choice":"auto"}"#;
|
||||
let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
|
||||
assert_eq!(de_auto.tool_choice, ToolChoice(Some(ToolType::OneOf)));
|
||||
assert_eq!(de_auto.tool_choice, ChatCompletionToolChoiceOption::Auto);
|
||||
|
||||
let ref_choice = ToolChoice(Some(ToolType::Function(FunctionName {
|
||||
let auto = r#"{"tool_choice":"required"}"#;
|
||||
let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
|
||||
assert_eq!(
|
||||
de_auto.tool_choice,
|
||||
ChatCompletionToolChoiceOption::Required
|
||||
);
|
||||
|
||||
let ref_choice = ChatCompletionToolChoiceOption::Function(FunctionName {
|
||||
name: "myfn".to_string(),
|
||||
})));
|
||||
});
|
||||
|
||||
let named = r#"{"tool_choice":"myfn"}"#;
|
||||
let de_named: TestRequest = serde_json::from_str(named).unwrap();
|
||||
|
@ -28,7 +28,7 @@ use crate::{
|
||||
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
|
||||
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
|
||||
};
|
||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
|
||||
use crate::{ChatCompletionToolChoiceOption, FunctionDefinition, HubPreprocessorConfig, ToolCall};
|
||||
use crate::{ModelInfo, ModelsInfo};
|
||||
use async_stream::__private::AsyncStream;
|
||||
use axum::extract::Extension;
|
||||
@ -1551,12 +1551,11 @@ GrammarType,
|
||||
Usage,
|
||||
StreamOptions,
|
||||
DeltaToolCall,
|
||||
ToolType,
|
||||
Tool,
|
||||
ToolCall,
|
||||
Function,
|
||||
FunctionDefinition,
|
||||
ToolChoice,
|
||||
ChatCompletionToolChoiceOption,
|
||||
ModelInfo,
|
||||
)
|
||||
),
|
||||
@ -2522,7 +2521,7 @@ pub(crate) fn prepare_chat_input(
|
||||
infer: &Infer,
|
||||
response_format: Option<GrammarType>,
|
||||
tools: Option<Vec<Tool>>,
|
||||
tool_choice: ToolChoice,
|
||||
tool_choice: ChatCompletionToolChoiceOption,
|
||||
tool_prompt: &str,
|
||||
guideline: Option<String>,
|
||||
messages: Vec<Message>,
|
||||
@ -2660,7 +2659,7 @@ mod tests {
|
||||
&infer,
|
||||
response_format,
|
||||
tools,
|
||||
ToolChoice(None),
|
||||
ChatCompletionToolChoiceOption::Auto,
|
||||
tool_prompt,
|
||||
guideline,
|
||||
messages,
|
||||
|
Loading…
Reference in New Issue
Block a user