add OpenAI like tool_choice for named choice

This commit is contained in:
Linus Bierhoff 2024-10-10 18:50:32 +02:00 committed by drbh
parent 98330df65e
commit f979ff1965
2 changed files with 30 additions and 0 deletions

View File

@ -2268,6 +2268,24 @@
"$ref": "#/components/schemas/FunctionName" "$ref": "#/components/schemas/FunctionName"
} }
} }
},
{
"type": "object",
"required": [
"type",
"function"
],
"properties": {
"type": {
"type": "string",
"enum": [
"function"
]
},
"function": {
"$ref": "#/components/schemas/FunctionName"
}
}
} }
], ],
"description": "Controls which (if any) tool is called by the model.", "description": "Controls which (if any) tool is called by the model.",

View File

@ -1011,9 +1011,18 @@ pub enum ToolType {
NoTool, NoTool,
/// Forces the model to call a specific tool. /// Forces the model to call a specific tool.
#[schema(rename = "function")] #[schema(rename = "function")]
#[serde(alias = "function")]
Function(FunctionName), Function(FunctionName),
} }
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(tag = "type")]
pub enum TypedChoice {
#[serde(rename = "function")]
Function{function: FunctionName},
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
pub struct FunctionName { pub struct FunctionName {
pub name: String, pub name: String,
@ -1029,8 +1038,10 @@ enum ToolTypeDeserializer {
Null, Null,
String(String), String(String),
ToolType(ToolType), ToolType(ToolType),
TypedChoice(TypedChoice) //this is the OpenAI schema
} }
impl From<ToolTypeDeserializer> for ToolChoice { impl From<ToolTypeDeserializer> for ToolChoice {
fn from(value: ToolTypeDeserializer) -> Self { fn from(value: ToolTypeDeserializer) -> Self {
match value { match value {
@ -1040,6 +1051,7 @@ impl From<ToolTypeDeserializer> for ToolChoice {
"auto" => ToolChoice(Some(ToolType::OneOf)), "auto" => ToolChoice(Some(ToolType::OneOf)),
_ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))), _ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))),
}, },
ToolTypeDeserializer::TypedChoice(TypedChoice::Function{function}) => ToolChoice(Some(ToolType::Function(function))),
ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)), ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)),
} }
} }