diff --git a/router/src/lib.rs b/router/src/lib.rs index 566335eb..98424497 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -591,44 +591,39 @@ mod deserialize_tool_choice { } } -#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)] +#[derive(Debug, Deserialize, Serialize, ToSchema)] pub struct Tools { - // rename to "$function" to avoid conflicts with other fields - #[serde(rename = "$function")] - pub function: std::collections::HashMap, - pub any_of: Vec, + #[serde(flatten)] + functions_map: FunctionsMap, + properties: Properties, } -// Allows Tools to be converted to a valid JSON schema object -impl From for serde_json::Value { - fn from(tools: Tools) -> Self { - let mut map = serde_json::Map::new(); - let mut functions = serde_json::Map::new(); - for (name, value) in tools.function { - functions.insert(name, value); - } - map.insert("$functions".to_string(), serde_json::json!(functions)); - let mut properties = serde_json::Map::new(); - let mut function = serde_json::Map::new(); - function.insert("anyOf".to_string(), serde_json::json!(tools.any_of)); - properties.insert("function".to_string(), serde_json::json!(function)); - map.insert("properties".to_string(), serde_json::json!(properties)); - serde_json::Value::Object(map) - } +#[derive(Debug, Serialize, Deserialize)] +struct FunctionsMap { + #[serde(rename = "$functions")] + functions: std::collections::HashMap, } -#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] -pub struct FunctionRef { +#[derive(Debug, Serialize, Deserialize)] +struct FunctionRef { #[serde(rename = "$ref")] - pub _ref: String, + ref_path: String, } -impl FunctionRef { - pub fn new(name: &str) -> Self { - Self { - _ref: format!("#/$functions/{}", name), - } - } +#[derive(Debug, Serialize, Deserialize)] +struct Properties { + #[serde(serialize_with = "serialize_function")] + function: Vec, +} + +fn serialize_function(functions: &Vec, serializer: S) -> Result +where + S: serde::Serializer, +{ + use serde::ser::SerializeStruct; + let mut state = serializer.serialize_struct("Function", 1)?; + state.serialize_field("anyOf", functions)?; + state.end() } #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] diff --git a/router/src/server.rs b/router/src/server.rs index 09f7c3e7..e3254625 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -10,7 +10,7 @@ use crate::{ HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse, }; -use crate::{Function, FunctionRef, ToolCall, ToolType, Tools}; +use crate::{Function, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools}; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; @@ -629,11 +629,15 @@ async fn chat_completions( .collect(); let tools = Tools { - function: functions, - any_of: tools_to_use - .iter() - .map(|tool| FunctionRef::new(&tool.function.name)) - .collect(), + functions_map: FunctionsMap { functions }, + properties: Properties { + function: tools_to_use + .iter() + .map(|tool| FunctionRef { + ref_path: format!("#/$functions/{}", tool.function.name.clone()), + }) + .collect(), + }, }; let tools_str = serde_json::to_string(&tools).map_err(|e| { @@ -646,7 +650,7 @@ async fn chat_completions( ) })?; inputs = format!("{inputs}{tool_prompt}{tools_str}"); - Some(GrammarType::Json(tools.into())) + Some(GrammarType::Json(serde_json::json!(tools))) } else { None };