diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index f3b10450..efbaf480 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -336,8 +336,14 @@ impl ToolGrammar { tools: Option>, tool_choice: Option, ) -> Result, InferError> { - if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) { - // let tool_prompt = tool_prompt.unwrap_or_default(); + if let Some(req_tools) = tools { + let tool_choice = tool_choice + .map(|t| match t { + ToolType::FunctionName(name) if name == "auto" => ToolType::OneOf, + _ => t, + }) + .unwrap_or_default(); + let tools_to_use = match tool_choice { ToolType::FunctionName(name) => { vec![req_tools diff --git a/router/src/lib.rs b/router/src/lib.rs index f856406d..be02e645 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -840,12 +840,16 @@ fn default_tool_prompt() -> Option { ) } -#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] +#[derive(Clone, Default, Debug, Deserialize, PartialEq, Serialize, ToSchema)] #[serde(untagged)] pub enum ToolType { + #[default] + #[serde(alias = "auto")] OneOf, FunctionName(String), - Function { function: FunctionName }, + Function { + function: FunctionName, + }, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]