diff --git a/router/src/lib.rs b/router/src/lib.rs index 1c06eb8a..67a14d5c 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -520,6 +520,46 @@ pub(crate) struct ChatRequest { #[serde(default)] #[schema(nullable = true, example = 0.95)] pub top_p: Option, + + /// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of + /// functions the model may generate JSON inputs for. + #[serde(default)] + #[schema(nullable = true, example = "null")] + pub tools: Option>, +} + +// TODO: define and use better types for tools + +// #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] +// enum ToolType { +// #[serde(rename = "function")] +// Function, +// } + +// impl Default for ToolType { +// fn default() -> Self { +// ToolType::Function +// } +// } + +// #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] +// pub(crate) struct Function { +// pub description: String, +// pub name: String, +// #[serde( +// rename = "json", +// deserialize_with = "json_object_or_string_to_string::deserialize" +// )] +// pub parameters: String, +// } + +#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] +pub(crate) struct Tool { + // The type of the tool. Currently, only 'function' is supported. + #[schema(example = "function")] + pub r#type: String, + // Grab the tool as generic JSON for debugging purposes. + pub function: serde_json::Value, } #[derive(Clone, Serialize, Deserialize)] diff --git a/router/src/server.rs b/router/src/server.rs index 9fdd66cc..8f40ee2a 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -580,8 +580,47 @@ async fn chat_completions( let logprobs = req.logprobs.unwrap_or(false); let seed = req.seed; + // Build a new JSON schema that defines the "$functions" object + // and requires the grammar to choose anyOf the functions defined. + let mut tools = serde_json::json!({}); + + // First decompose the tools and use the function name as the key + // and the parameters as the value in the "$functions" object. + if let Some(req_tools) = &req.tools { + for tool in req_tools { + let func = tool.function.clone(); + let name = func.get("name").unwrap().as_str().unwrap(); + let parameters = func.get("parameters").unwrap().as_object().unwrap().clone(); + // add a entry to the "$functions" object + tools["$functions"][name] = serde_json::Value::Object(parameters); + } + + // now add the properties to the root object + tools["properties"]["function"]["anyOf"] = serde_json::Value::Array( + req.tools + .as_ref() + .unwrap() + .iter() + // map each tool to a $ref to the function + .map(|tool| { + let func = tool.function.clone(); + let name = func.get("name").unwrap().as_str().unwrap(); + serde_json::json!({ + "$ref": format!("#/$functions/{}", name) + }) + }) + .collect(), + ); + } + + // only add grammar if tools are present + let grammar = match req.tools { + Some(_grammar) => Some(crate::GrammarType::Json(tools.to_string())), + None => None, + }; + // apply chat template to flatten the request into a single input - let inputs = match infer.apply_chat_template(req.messages) { + let mut inputs = match infer.apply_chat_template(req.messages) { Ok(inputs) => inputs, Err(err) => { metrics::increment_counter!("tgi_request_failure", "err" => "validation"); @@ -596,6 +635,11 @@ async fn chat_completions( } }; + // append the tools to the inputs with TOOL prompt + let tool_prompt = + "Based on the conversation, please choose the most appropriate tool to use:".to_string(); + inputs = format!("{inputs}\n\n{tool_prompt}\n\n{tools}\n\n"); + // build the request passing some parameters let generate_request = GenerateRequest { inputs: inputs.to_string(), @@ -617,7 +661,7 @@ async fn chat_completions( decoder_input_details: !stream, seed, top_n_tokens: None, - grammar: None, + grammar, }, }; diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 40f31ce2..c9a32ff7 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -513,6 +513,9 @@ class GrammarLogitProcessor(LogitsProcessor): start_time = time.time() if grammar_type == GrammarType.GRAMMAR_TYPE_JSON: schema = build_regex_from_object(schema) + elif grammar_type == GrammarType.OPTIONAL_GRAMMAR_TYPE_REGEX: + # TODO: use a better method to handle optional grammars + schema = f"({build_regex_from_object(schema)})|.*" elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX: pass # schema is already a regex just here for clarity fsm = RegexFSM(schema, tokenizer)