diff --git a/router/src/lib.rs b/router/src/lib.rs
index 67a14d5c..3365bf98 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -528,30 +528,56 @@ pub(crate) struct ChatRequest {
     pub tools: Option<Vec<Tool>>,
 }
 
-// TODO: define and use better types for tools
+/// Tool definitions of a chat request, laid out as a JSON-schema fragment.
+///
+/// `function` maps each tool name to its parameter schema; `any_of` holds one
+/// `FunctionRef` per tool, pointing back into that map.
+#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
+pub struct Tools {
+    // Serialized as "$functions" so the key matches the "#/$functions/<name>"
+    // pointers built by `FunctionRef::new`.
+    #[serde(rename = "$functions")]
+    pub function: std::collections::HashMap<String, serde_json::Value>,
+    pub any_of: Vec<FunctionRef>,
+}
 
-// #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
-// enum ToolType {
-//     #[serde(rename = "function")]
-//     Function,
-// }
+/// Builds the JSON value used as the generation grammar:
+/// `{"$functions": {..}, "properties": {"function": {"anyOf": [..]}}}`.
+impl From<Tools> for serde_json::Value {
+    fn from(tools: Tools) -> Self {
+        serde_json::json!({
+            "$functions": tools.function,
+            "properties": {
+                "function": { "anyOf": tools.any_of }
+            }
+        })
+    }
+}
 
-// impl Default for ToolType {
-//     fn default() -> Self {
-//         ToolType::Function
-//     }
-// }
+/// JSON pointer (`{"$ref": "#/$functions/<name>"}`) to one entry of the `$functions` map.
+#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
+pub struct FunctionRef {
+    #[serde(rename = "$ref")]
+    pub _ref: String,
+}
 
-// #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
-// pub(crate) struct Function {
-//     pub description: String,
-//     pub name: String,
-//     #[serde(
-//         rename = "json",
-//         deserialize_with = "json_object_or_string_to_string::deserialize"
-//     )]
-//     pub parameters: String,
-// }
+impl FunctionRef {
+    /// Creates a reference to the tool called `name`.
+    pub fn new(name: &str) -> Self {
+        Self {
+            _ref: format!("#/$functions/{}", name),
+        }
+    }
+}
+
+/// Function description carried by a `Tool`.
+#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
+pub(crate) struct Function {
+    pub description: String,
+    pub name: String,
+    /// Parameter schema as raw JSON; the chat endpoint rejects non-object values.
+    pub parameters: serde_json::Value,
+}
 
 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
 pub(crate) struct Tool {
@@ -559,7 +585,7 @@ pub(crate) struct Tool {
     #[schema(example = "function")]
     pub r#type: String,
-    // Grab the tool as generic JSON for debugging purposes.
-    pub function: serde_json::Value,
+    /// Typed description of the function this tool exposes.
+    pub function: Function,
 }
 
 #[derive(Clone, Serialize, Deserialize)]
diff --git a/router/src/server.rs b/router/src/server.rs
index 8f40ee2a..e4764ff4 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -10,6 +10,7 @@ use crate::{
     HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
     StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
 };
+use crate::{FunctionRef, Tools};
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
 use axum::response::sse::{Event, KeepAlive, Sse};
@@ -22,6 +23,8 @@ use futures::stream::StreamExt;
 use futures::Stream;
 use futures::TryStreamExt;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
+use serde_json::Value;
+use std::collections::HashMap;
 use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
@@ -580,45 +583,6 @@ async fn chat_completions(
     let logprobs = req.logprobs.unwrap_or(false);
     let seed = req.seed;
 
-    // Build a new JSON schema that defines the "$functions" object
-    // and requires the grammar to choose anyOf the functions defined.
-    let mut tools = serde_json::json!({});
-
-    // First decompose the tools and use the function name as the key
-    // and the parameters as the value in the "$functions" object.
-    if let Some(req_tools) = &req.tools {
-        for tool in req_tools {
-            let func = tool.function.clone();
-            let name = func.get("name").unwrap().as_str().unwrap();
-            let parameters = func.get("parameters").unwrap().as_object().unwrap().clone();
-            // add a entry to the "$functions" object
-            tools["$functions"][name] = serde_json::Value::Object(parameters);
-        }
-
-        // now add the properties to the root object
-        tools["properties"]["function"]["anyOf"] = serde_json::Value::Array(
-            req.tools
-                .as_ref()
-                .unwrap()
-                .iter()
-                // map each tool to a $ref to the function
-                .map(|tool| {
-                    let func = tool.function.clone();
-                    let name = func.get("name").unwrap().as_str().unwrap();
-                    serde_json::json!({
-                        "$ref": format!("#/$functions/{}", name)
-                    })
-                })
-                .collect(),
-        );
-    }
-
-    // only add grammar if tools are present
-    let grammar = match req.tools {
-        Some(_grammar) => Some(crate::GrammarType::Json(tools.to_string())),
-        None => None,
-    };
-
     // apply chat template to flatten the request into a single input
     let mut inputs = match infer.apply_chat_template(req.messages) {
         Ok(inputs) => inputs,
@@ -635,10 +599,47 @@ async fn chat_completions(
         }
     };
 
-    // append the tools to the inputs with TOOL prompt
-    let tool_prompt =
-        "Based on the conversation, please choose the most appropriate tool to use:".to_string();
-    inputs = format!("{inputs}\n\n{tool_prompt}\n\n{tools}\n\n");
+    // If the request carries tools, build a JSON schema whose "$functions" object maps
+    // each function name to its parameters and whose "anyOf" lists one "$ref" per
+    // function. The same schema is appended to the prompt and used as the grammar.
+    let grammar = if let Some(req_tools) = &req.tools {
+        let mut functions: HashMap<String, Value> = HashMap::with_capacity(req_tools.len());
+        for tool in req_tools {
+            let function = &tool.function;
+            // Only JSON objects are valid parameter schemas.
+            if !function.parameters.is_object() {
+                let error = format!(
+                    "Input validation error: `parameters` of tool `{}` must be a JSON object",
+                    function.name
+                );
+                return Err((
+                    StatusCode::UNPROCESSABLE_ENTITY,
+                    Json(ErrorResponse {
+                        error,
+                        error_type: "Input validation error".to_string(),
+                    }),
+                ));
+            }
+            functions.insert(function.name.clone(), function.parameters.clone());
+        }
+
+        let tools_schema = Value::from(Tools {
+            function: functions,
+            any_of: req_tools
+                .iter()
+                .map(|tool| FunctionRef::new(&tool.function.name))
+                .collect(),
+        });
+
+        // append the schema to the inputs with the tool prompt
+        let tool_prompt =
+            "Based on the conversation, please choose the most appropriate tool to use:";
+        inputs = format!("{inputs}\n\n{tool_prompt}\n\n{tools_schema}\n\n");
+
+        Some(GrammarType::Json(tools_schema))
+    } else {
+        None
+    };
 
     // build the request passing some parameters
     let generate_request = GenerateRequest {