mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
feat: improve tool choice syntax and response parsing/errors
This commit is contained in:
parent
35f8a88db5
commit
21dc6776b1
@ -7,7 +7,7 @@ pub(crate) use health::HealthCheck;
|
||||
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
||||
use crate::{
|
||||
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
|
||||
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token,
|
||||
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, ToolChoice,
|
||||
};
|
||||
use crate::{
|
||||
FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
|
||||
@ -332,35 +332,37 @@ impl ChatTemplate {
|
||||
pub struct ToolGrammar {}
|
||||
|
||||
impl ToolGrammar {
|
||||
// find a tool by name
|
||||
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
|
||||
tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == name)
|
||||
.cloned()
|
||||
.ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
|
||||
}
|
||||
|
||||
pub fn apply(
|
||||
tools: Option<Vec<Tool>>,
|
||||
tool_choice: Option<ToolType>,
|
||||
tool_choice: ToolChoice,
|
||||
) -> Result<Option<Tools>, InferError> {
|
||||
if let Some(req_tools) = tools {
|
||||
let tool_choice = tool_choice
|
||||
.map(|t| match t {
|
||||
ToolType::FunctionName(name) if name == "auto" => ToolType::OneOf,
|
||||
_ => t,
|
||||
})
|
||||
.unwrap_or_default();
|
||||
// if no tools are provided, we return None
|
||||
let tools = match tools {
|
||||
Some(tools) if !tools.is_empty() => tools,
|
||||
_ => return Ok(None),
|
||||
};
|
||||
|
||||
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
|
||||
|
||||
// if tools are provided and no tool_choice we default to the OneOf
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolType::FunctionName(name) => {
|
||||
vec![req_tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == *name)
|
||||
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
|
||||
.clone()]
|
||||
vec![Self::find_tool_by_name(&tools, &name)?]
|
||||
}
|
||||
ToolType::Function { function } => {
|
||||
let tool = req_tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == function.name)
|
||||
.unwrap_or_else(|| panic!("Tool with name {} not found", function.name))
|
||||
.clone();
|
||||
vec![tool]
|
||||
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
||||
}
|
||||
ToolType::OneOf => req_tools.to_owned(),
|
||||
ToolType::OneOf => tools,
|
||||
ToolType::NoTool => return Ok(None),
|
||||
};
|
||||
|
||||
// adds the error notification function for LLM feedback if required
|
||||
@ -454,10 +456,7 @@ impl ToolGrammar {
|
||||
},
|
||||
};
|
||||
|
||||
return Ok(Some(tools));
|
||||
}
|
||||
// Err(InferError::ToolError("No tools provided".to_string()))
|
||||
Ok(None)
|
||||
Ok(Some(tools))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -824,7 +824,7 @@ pub(crate) struct ChatRequest {
|
||||
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub tool_choice: Option<ToolType>,
|
||||
pub tool_choice: ToolChoice,
|
||||
|
||||
/// Response format constraints for the generation.
|
||||
///
|
||||
@ -840,16 +840,13 @@ fn default_tool_prompt() -> Option<String> {
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
|
||||
#[serde(untagged)]
|
||||
pub enum ToolType {
|
||||
#[default]
|
||||
#[serde(alias = "auto")]
|
||||
OneOf,
|
||||
FunctionName(String),
|
||||
Function {
|
||||
function: FunctionName,
|
||||
},
|
||||
Function { function: FunctionName },
|
||||
NoTool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
|
||||
@ -857,27 +854,26 @@ pub struct FunctionName {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
|
||||
#[serde(from = "ToolTypeDeserializer")]
|
||||
pub struct ToolChoice(pub Option<ToolType>);
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ToolTypeDeserializer {
|
||||
None(Option<String>),
|
||||
Some(ToolType),
|
||||
String(String),
|
||||
ToolType(ToolType),
|
||||
}
|
||||
|
||||
impl From<ToolTypeDeserializer> for ToolChoice {
|
||||
fn from(value: ToolTypeDeserializer) -> Self {
|
||||
match value {
|
||||
ToolTypeDeserializer::None(opt) => match opt.as_deref() {
|
||||
Some("none") => ToolChoice(None),
|
||||
Some("auto") => ToolChoice(Some(ToolType::OneOf)),
|
||||
Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))),
|
||||
None => ToolChoice(Some(ToolType::OneOf)),
|
||||
ToolTypeDeserializer::String(s) => match s.as_str() {
|
||||
"none" => ToolChoice(Some(ToolType::NoTool)),
|
||||
"auto" => ToolChoice(Some(ToolType::OneOf)),
|
||||
_ => ToolChoice(Some(ToolType::FunctionName(s))),
|
||||
},
|
||||
ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)),
|
||||
ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1375,4 +1371,47 @@ mod tests {
|
||||
r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":{"format":"csv"}}}]}"#
|
||||
);
|
||||
}
|
||||
// Verifies the JSON wire forms accepted for `ToolCall` and `ToolChoice`.
#[test]
fn tool_deserialize() {
    // Test ToolCall deserialization
    let json = r#"{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":{"format":"csv"}}}"#;
    let tool: ToolCall = serde_json::from_str(json).unwrap();
    assert_eq!(
        tool,
        ToolCall {
            id: "0".to_string(),
            r#type: "function".to_string(),
            function: FunctionDefinition {
                description: None,
                name: "myfn".to_string(),
                arguments: json!({
                    "format": "csv"
                }),
            },
        }
    );

    // Test ToolChoice deserialization with "auto"
    let auto_json = r#""auto""#;
    let auto_choice: ToolChoice = serde_json::from_str(auto_json).unwrap();
    assert_eq!(auto_choice, ToolChoice(Some(ToolType::OneOf)));

    // Test ToolChoice deserialization with "none".
    // The conversion from the deserializer helper maps the literal "none" to the
    // explicit `NoTool` variant (not to an absent choice), so that is what we expect.
    let none_json = r#""none""#;
    let none_choice: ToolChoice = serde_json::from_str(none_json).unwrap();
    assert_eq!(none_choice, ToolChoice(Some(ToolType::NoTool)));

    // Test ToolChoice deserialization with a specific function name
    let function_json = r#""my_function""#;
    let function_choice: ToolChoice = serde_json::from_str(function_json).unwrap();
    assert_eq!(
        function_choice,
        ToolChoice(Some(ToolType::FunctionName("my_function".to_string())))
    );

    // Test ToolChoice deserialization with an explicit JSON null (should resolve to OneOf)
    let default_json = r#"null"#;
    let default_choice: ToolChoice = serde_json::from_str(default_json).unwrap();
    assert_eq!(default_choice, ToolChoice(Some(ToolType::OneOf)));
}
|
||||
}
|
||||
|
@ -1192,39 +1192,33 @@ async fn chat_completions(
|
||||
.as_secs();
|
||||
|
||||
let (tool_calls, output) = if tool_grammar.is_some() {
|
||||
// gen_text should be valid json
|
||||
let gen_text_value: Value =
|
||||
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: e.to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
let gen_text_value: Value = serde_json::from_str(&generation.generated_text)
|
||||
.map_err(|e| InferError::ToolError(e.to_string()))?;
|
||||
|
||||
let function = gen_text_value.get("function").ok_or(InferError::ToolError(
|
||||
"No function found in generated text".to_string(),
|
||||
))?;
|
||||
|
||||
let name = function
|
||||
.get("_name")
|
||||
.and_then(Value::as_str)
|
||||
.ok_or(InferError::ToolError(
|
||||
"No _name found in generated text".to_string(),
|
||||
))?
|
||||
.to_string();
|
||||
|
||||
let mut arguments = function.clone();
|
||||
if let Value::Object(ref mut props) = arguments {
|
||||
props.remove("_name");
|
||||
}
|
||||
|
||||
let tool_calls = vec![ToolCall {
|
||||
id: "0".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionDefinition {
|
||||
description: None,
|
||||
name: gen_text_value
|
||||
.get("function")
|
||||
.and_then(|f| f.get("_name"))
|
||||
.and_then(|name| name.as_str())
|
||||
.unwrap_or("default_function_name")
|
||||
.to_string(),
|
||||
// Serialize the JSON object obtained from "function" to an escaped JSON string
|
||||
arguments: gen_text_value
|
||||
.get("function")
|
||||
.map(|f| {
|
||||
let mut f_cloned = f.clone();
|
||||
if let Value::Object(ref mut props) = f_cloned {
|
||||
props.remove("_name");
|
||||
}
|
||||
f_cloned
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
name,
|
||||
arguments,
|
||||
},
|
||||
}];
|
||||
(Some(tool_calls), None)
|
||||
|
Loading…
Reference in New Issue
Block a user