mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 22:32:07 +00:00
fix: adjust default tool choice (#2244)
* fix: adjust default tool choice
* feat: improve tool choice syntax and response parsing/errors
* fix: remove dev tests
* feat: add ToolChoice to docs
This commit is contained in:
parent
40f5dc3ed6
commit
68a9685f1b
@ -909,7 +909,7 @@
|
|||||||
"tool_choice": {
|
"tool_choice": {
|
||||||
"allOf": [
|
"allOf": [
|
||||||
{
|
{
|
||||||
"$ref": "#/components/schemas/ToolType"
|
"$ref": "#/components/schemas/ToolChoice"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"nullable": true
|
"nullable": true
|
||||||
@ -2035,6 +2035,14 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"ToolChoice": {
|
||||||
|
"allOf": [
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/ToolType"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nullable": true
|
||||||
|
},
|
||||||
"ToolType": {
|
"ToolType": {
|
||||||
"oneOf": [
|
"oneOf": [
|
||||||
{
|
{
|
||||||
@ -2055,6 +2063,11 @@
|
|||||||
"$ref": "#/components/schemas/FunctionName"
|
"$ref": "#/components/schemas/FunctionName"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"default": null,
|
||||||
|
"nullable": true
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -7,7 +7,7 @@ pub(crate) use health::HealthCheck;
|
|||||||
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
|
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
|
||||||
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token,
|
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, ToolChoice,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
|
FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
|
||||||
@ -332,29 +332,37 @@ impl ChatTemplate {
|
|||||||
pub struct ToolGrammar {}
|
pub struct ToolGrammar {}
|
||||||
|
|
||||||
impl ToolGrammar {
|
impl ToolGrammar {
|
||||||
|
// find a tool by name
|
||||||
|
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
|
||||||
|
tools
|
||||||
|
.iter()
|
||||||
|
.find(|tool| tool.function.name == name)
|
||||||
|
.cloned()
|
||||||
|
.ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn apply(
|
pub fn apply(
|
||||||
tools: Option<Vec<Tool>>,
|
tools: Option<Vec<Tool>>,
|
||||||
tool_choice: Option<ToolType>,
|
tool_choice: ToolChoice,
|
||||||
) -> Result<Option<Tools>, InferError> {
|
) -> Result<Option<Tools>, InferError> {
|
||||||
if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
|
// if no tools are provided, we return None
|
||||||
// let tool_prompt = tool_prompt.unwrap_or_default();
|
let tools = match tools {
|
||||||
|
Some(tools) if !tools.is_empty() => tools,
|
||||||
|
_ => return Ok(None),
|
||||||
|
};
|
||||||
|
|
||||||
|
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
|
||||||
|
|
||||||
|
// if tools are provided and no tool_choice we default to the OneOf
|
||||||
let tools_to_use = match tool_choice {
|
let tools_to_use = match tool_choice {
|
||||||
ToolType::FunctionName(name) => {
|
ToolType::FunctionName(name) => {
|
||||||
vec![req_tools
|
vec![Self::find_tool_by_name(&tools, &name)?]
|
||||||
.iter()
|
|
||||||
.find(|tool| tool.function.name == *name)
|
|
||||||
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
|
|
||||||
.clone()]
|
|
||||||
}
|
}
|
||||||
ToolType::Function { function } => {
|
ToolType::Function { function } => {
|
||||||
let tool = req_tools
|
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
||||||
.iter()
|
|
||||||
.find(|tool| tool.function.name == function.name)
|
|
||||||
.unwrap_or_else(|| panic!("Tool with name {} not found", function.name))
|
|
||||||
.clone();
|
|
||||||
vec![tool]
|
|
||||||
}
|
}
|
||||||
ToolType::OneOf => req_tools.to_owned(),
|
ToolType::OneOf => tools,
|
||||||
|
ToolType::NoTool => return Ok(None),
|
||||||
};
|
};
|
||||||
|
|
||||||
// adds the error notification function for LLM feedback if required
|
// adds the error notification function for LLM feedback if required
|
||||||
@ -448,10 +456,7 @@ impl ToolGrammar {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
return Ok(Some(tools));
|
Ok(Some(tools))
|
||||||
}
|
|
||||||
// Err(InferError::ToolError("No tools provided".to_string()))
|
|
||||||
Ok(None)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -826,7 +826,7 @@ pub(crate) struct ChatRequest {
|
|||||||
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(nullable = true, example = "null")]
|
#[schema(nullable = true, example = "null")]
|
||||||
pub tool_choice: Option<ToolType>,
|
pub tool_choice: ToolChoice,
|
||||||
|
|
||||||
/// Response format constraints for the generation.
|
/// Response format constraints for the generation.
|
||||||
///
|
///
|
||||||
@ -848,6 +848,7 @@ pub enum ToolType {
|
|||||||
OneOf,
|
OneOf,
|
||||||
FunctionName(String),
|
FunctionName(String),
|
||||||
Function { function: FunctionName },
|
Function { function: FunctionName },
|
||||||
|
NoTool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
|
||||||
@ -855,27 +856,26 @@ pub struct FunctionName {
|
|||||||
pub name: String,
|
pub name: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)]
|
||||||
#[serde(from = "ToolTypeDeserializer")]
|
#[serde(from = "ToolTypeDeserializer")]
|
||||||
pub struct ToolChoice(pub Option<ToolType>);
|
pub struct ToolChoice(pub Option<ToolType>);
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
enum ToolTypeDeserializer {
|
enum ToolTypeDeserializer {
|
||||||
None(Option<String>),
|
String(String),
|
||||||
Some(ToolType),
|
ToolType(ToolType),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<ToolTypeDeserializer> for ToolChoice {
|
impl From<ToolTypeDeserializer> for ToolChoice {
|
||||||
fn from(value: ToolTypeDeserializer) -> Self {
|
fn from(value: ToolTypeDeserializer) -> Self {
|
||||||
match value {
|
match value {
|
||||||
ToolTypeDeserializer::None(opt) => match opt.as_deref() {
|
ToolTypeDeserializer::String(s) => match s.as_str() {
|
||||||
Some("none") => ToolChoice(None),
|
"none" => ToolChoice(Some(ToolType::NoTool)),
|
||||||
Some("auto") => ToolChoice(Some(ToolType::OneOf)),
|
"auto" => ToolChoice(Some(ToolType::OneOf)),
|
||||||
Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))),
|
_ => ToolChoice(Some(ToolType::FunctionName(s))),
|
||||||
None => ToolChoice(Some(ToolType::OneOf)),
|
|
||||||
},
|
},
|
||||||
ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)),
|
ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -24,7 +24,7 @@ use crate::{
|
|||||||
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
|
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
|
||||||
VertexResponse,
|
VertexResponse,
|
||||||
};
|
};
|
||||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType};
|
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
|
||||||
use async_stream::__private::AsyncStream;
|
use async_stream::__private::AsyncStream;
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, Method, StatusCode};
|
use axum::http::{HeaderMap, Method, StatusCode};
|
||||||
@ -1192,39 +1192,33 @@ async fn chat_completions(
|
|||||||
.as_secs();
|
.as_secs();
|
||||||
|
|
||||||
let (tool_calls, output) = if tool_grammar.is_some() {
|
let (tool_calls, output) = if tool_grammar.is_some() {
|
||||||
// gen_text should be valid json
|
let gen_text_value: Value = serde_json::from_str(&generation.generated_text)
|
||||||
let gen_text_value: Value =
|
.map_err(|e| InferError::ToolError(e.to_string()))?;
|
||||||
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
|
||||||
(
|
let function = gen_text_value.get("function").ok_or(InferError::ToolError(
|
||||||
StatusCode::UNPROCESSABLE_ENTITY,
|
"No function found in generated text".to_string(),
|
||||||
Json(ErrorResponse {
|
))?;
|
||||||
error: e.to_string(),
|
|
||||||
error_type: "Input validation error".to_string(),
|
let name = function
|
||||||
}),
|
.get("_name")
|
||||||
)
|
.and_then(Value::as_str)
|
||||||
})?;
|
.ok_or(InferError::ToolError(
|
||||||
|
"No _name found in generated text".to_string(),
|
||||||
|
))?
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let mut arguments = function.clone();
|
||||||
|
if let Value::Object(ref mut props) = arguments {
|
||||||
|
props.remove("_name");
|
||||||
|
}
|
||||||
|
|
||||||
let tool_calls = vec![ToolCall {
|
let tool_calls = vec![ToolCall {
|
||||||
id: "0".to_string(),
|
id: "0".to_string(),
|
||||||
r#type: "function".to_string(),
|
r#type: "function".to_string(),
|
||||||
function: FunctionDefinition {
|
function: FunctionDefinition {
|
||||||
description: None,
|
description: None,
|
||||||
name: gen_text_value
|
name,
|
||||||
.get("function")
|
arguments,
|
||||||
.and_then(|f| f.get("_name"))
|
|
||||||
.and_then(|name| name.as_str())
|
|
||||||
.unwrap_or("default_function_name")
|
|
||||||
.to_string(),
|
|
||||||
// Serialize the JSON object obtained from "function" to an escaped JSON string
|
|
||||||
arguments: gen_text_value
|
|
||||||
.get("function")
|
|
||||||
.map(|f| {
|
|
||||||
let mut f_cloned = f.clone();
|
|
||||||
if let Value::Object(ref mut props) = f_cloned {
|
|
||||||
props.remove("_name");
|
|
||||||
}
|
|
||||||
f_cloned
|
|
||||||
})
|
|
||||||
.unwrap_or_default(),
|
|
||||||
},
|
},
|
||||||
}];
|
}];
|
||||||
(Some(tool_calls), None)
|
(Some(tool_calls), None)
|
||||||
@ -1498,6 +1492,7 @@ pub async fn run(
|
|||||||
ToolCall,
|
ToolCall,
|
||||||
Function,
|
Function,
|
||||||
FunctionDefinition,
|
FunctionDefinition,
|
||||||
|
ToolChoice,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
tags(
|
tags(
|
||||||
|
Loading…
Reference in New Issue
Block a user