fix: consolidate changes and remove old tool type

This commit is contained in:
David Holtz 2024-10-14 16:44:54 +00:00 committed by drbh
parent 2c172a2da7
commit 209f841767
2 changed files with 15 additions and 18 deletions

View File

@ -2245,12 +2245,18 @@
"ToolType": { "ToolType": {
"oneOf": [ "oneOf": [
{ {
"type": "object", "type": "string",
"default": null, "description": "Means the model can pick between generating a message or calling one or more tools.",
"nullable": true "enum": [
"auto"
]
}, },
{ {
"type": "string" "type": "string",
"description": "Means the model will not call any tool and instead generates a message.",
"enum": [
"none"
]
}, },
{ {
"type": "object", "type": "object",
@ -2262,13 +2268,10 @@
"$ref": "#/components/schemas/FunctionName" "$ref": "#/components/schemas/FunctionName"
} }
} }
},
{
"type": "object",
"default": null,
"nullable": true
} }
] ],
"description": "Controls which (if any) tool is called by the model.",
"example": "auto"
}, },
"Url": { "Url": {
"type": "object", "type": "object",

View File

@ -1036,8 +1036,7 @@ pub struct ToolChoice(pub Option<ToolType>);
enum ToolTypeDeserializer { enum ToolTypeDeserializer {
Null, Null,
String(String), String(String),
ToolType(ToolType), ToolType(TypedChoice),
TypedChoice(TypedChoice), //this is the OpenAI schema
} }
impl From<ToolTypeDeserializer> for ToolChoice { impl From<ToolTypeDeserializer> for ToolChoice {
@ -1049,10 +1048,9 @@ impl From<ToolTypeDeserializer> for ToolChoice {
"auto" => ToolChoice(Some(ToolType::OneOf)), "auto" => ToolChoice(Some(ToolType::OneOf)),
_ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))), _ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))),
}, },
ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => { ToolTypeDeserializer::ToolType(TypedChoice::Function { function }) => {
ToolChoice(Some(ToolType::Function(function))) ToolChoice(Some(ToolType::Function(function)))
} }
ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)),
} }
} }
} }
@ -1682,10 +1680,6 @@ mod tests {
let de_named: TestRequest = serde_json::from_str(named).unwrap(); let de_named: TestRequest = serde_json::from_str(named).unwrap();
assert_eq!(de_named.tool_choice, ref_choice); assert_eq!(de_named.tool_choice, ref_choice);
let old_named = r#"{"tool_choice":{"function":{"name":"myfn"}}}"#;
let de_old_named: TestRequest = serde_json::from_str(old_named).unwrap();
assert_eq!(de_old_named.tool_choice, ref_choice);
let openai_named = r#"{"tool_choice":{"type":"function","function":{"name":"myfn"}}}"#; let openai_named = r#"{"tool_choice":{"type":"function","function":{"name":"myfn"}}}"#;
let de_openai_named: TestRequest = serde_json::from_str(openai_named).unwrap(); let de_openai_named: TestRequest = serde_json::from_str(openai_named).unwrap();