text-generation-inference/router/src/infer/tool_grammar.rs
drbh 8ad20daf33
CI (2599): Update ToolType input schema (#2601)
* Update ToolType input schema

* lint

* fix: run formatter

* fix: allow tool choice to be null

---------

Co-authored-by: Wauplin <lucainp@gmail.com>
2024-10-08 12:35:48 -04:00

124 lines
4.2 KiB
Rust

use crate::infer::InferError;
use crate::{
FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
ToolType,
};
use serde_json::{json, Map, Value};
use std::collections::HashMap;
/// Stateless, field-less type that groups the associated functions used to turn a
/// list of [`Tool`]s and a [`ToolChoice`] into an optional [`JsonSchemaTool`].
pub(crate) struct ToolGrammar {}
impl ToolGrammar {
// find a tool by name
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
tools
.iter()
.find(|tool| tool.function.name == name)
.cloned()
.ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
}
pub fn apply(
tools: Vec<Tool>,
tool_choice: ToolChoice,
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
// if no tools are provided, we return None
if tools.is_empty() {
return Ok((tools, None));
}
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
let mut tools = tools.clone();
// add the notify_error function to the tools
let notify_error = Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
name: "notify_error".to_string(),
description: Some("Notify an error or issue".to_string()),
arguments: json!({
"type": "object",
"properties": {
"error": {
"type": "string",
"description": "The error or issue to notify"
}
},
"required": ["error"]
}),
},
};
tools.push(notify_error);
// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice {
ToolType::Function(function) => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ToolType::OneOf => tools.clone(),
ToolType::NoTool => return Ok((tools, None)),
};
let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
let mut params = Map::new();
params.insert(
"description".to_string(),
Value::String(func.description.unwrap_or_default()),
);
let mut properties = Map::new();
let mut required = vec![Value::String("_name".to_string())];
properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": func.name.clone(),
}),
);
if let Value::Object(args) = func.arguments {
if let Some(Value::Object(props)) = args.get("properties") {
properties.extend(props.clone());
}
if let Some(Value::Array(reqs)) = args.get("required") {
required.extend(reqs.clone());
}
params.insert(
"additionalProperties".to_string(),
Value::Bool(
args.get("additionalProperties").and_then(|v| v.as_str())
== Some("true"),
),
);
}
params.insert("properties".to_string(), Value::Object(properties));
params.insert("required".to_string(), Value::Array(required));
(func.name, Value::Object(params))
})
.collect();
let tool_schema = JsonSchemaTool {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.collect(),
},
};
Ok((tools, Some(tool_schema)))
}
}