Update ToolType input schema

This commit is contained in:
Wauplin 2024-10-02 12:01:10 +02:00
parent 5b6b74e21d
commit aae6db9cd0
No known key found for this signature in database
GPG Key ID: 9838FE02BECE1A02
3 changed files with 26 additions and 18 deletions

View File

@ -2114,12 +2114,18 @@
"ToolType": {
"oneOf": [
{
"type": "object",
"default": null,
"nullable": true
"type": "string",
"description": "Means the model can pick between generating a message or calling one or more tools.",
"enum": [
"auto"
]
},
{
"type": "string"
"type": "string",
"description": "Means the model will not call any tool and instead generates a message.",
"enum": [
"none"
]
},
{
"type": "object",
@ -2131,13 +2137,10 @@
"$ref": "#/components/schemas/FunctionName"
}
}
},
{
"type": "object",
"default": null,
"nullable": true
}
]
],
"description": "Controls which (if any) tool is called by the model.",
"example": "auto"
},
"Url": {
"type": "object",

View File

@ -53,10 +53,7 @@ impl ToolGrammar {
// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![Self::find_tool_by_name(&tools, &name)?]
}
ToolType::Function { function } => {
ToolType::Function ( function ) => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ToolType::OneOf => tools.clone(),

View File

@ -957,12 +957,18 @@ pub fn default_tool_prompt() -> String {
}
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
#[serde(untagged)]
#[schema(example = "auto")]
/// Controls which (if any) tool is called by the model.
pub enum ToolType {
/// Means the model can pick between generating a message or calling one or more tools.
#[schema(rename = "auto")]
OneOf,
FunctionName(String),
Function { function: FunctionName },
/// Means the model will not call any tool and instead generates a message.
#[schema(rename = "none")]
NoTool,
/// Forces the model to call a specific tool.
#[schema(rename = "function")]
Function(FunctionName),
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
@ -987,7 +993,9 @@ impl From<ToolTypeDeserializer> for ToolChoice {
ToolTypeDeserializer::String(s) => match s.as_str() {
"none" => ToolChoice(Some(ToolType::NoTool)),
"auto" => ToolChoice(Some(ToolType::OneOf)),
_ => ToolChoice(Some(ToolType::FunctionName(s))),
_ => ToolChoice(Some(ToolType::Function(FunctionName {
name: s
} ))),
},
ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)),
}