feat: add concrete tool types

This commit is contained in:
drbh 2024-02-22 04:19:47 +00:00
parent 1aa2126206
commit 014d3fd4ef
2 changed files with 102 additions and 65 deletions

View File

@ -528,30 +528,53 @@ pub(crate) struct ChatRequest {
pub tools: Option<Vec<Tool>>, pub tools: Option<Vec<Tool>>,
} }
// TODO: define and use better types for tools #[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
pub struct Tools {
// rename to "$function" to avoid conflicts with other fields
#[serde(rename = "$function")]
pub function: std::collections::HashMap<String, serde_json::Value>,
pub any_of: Vec<FunctionRef>,
}
// #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] // add traut to convert to serde_json::Value for tools
// enum ToolType { impl From<Tools> for serde_json::Value {
// #[serde(rename = "function")] fn from(tools: Tools) -> Self {
// Function, println!("tools: {:?}", tools);
// } let mut map = serde_json::Map::new();
let mut functions = serde_json::Map::new();
for (name, value) in tools.function {
functions.insert(name, value);
}
map.insert("$functions".to_string(), serde_json::json!(functions));
let mut properties = serde_json::Map::new();
let mut function = serde_json::Map::new();
function.insert("anyOf".to_string(), serde_json::json!(tools.any_of));
properties.insert("function".to_string(), serde_json::json!(function));
map.insert("properties".to_string(), serde_json::json!(properties));
serde_json::Value::Object(map)
}
}
// impl Default for ToolType { #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
// fn default() -> Self { pub struct FunctionRef {
// ToolType::Function #[serde(rename = "$ref")]
// } pub _ref: String,
// } }
// #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] impl FunctionRef {
// pub(crate) struct Function { pub fn new(name: &str) -> Self {
// pub description: String, Self {
// pub name: String, _ref: format!("#/$functions/{}", name),
// #[serde( }
// rename = "json", }
// deserialize_with = "json_object_or_string_to_string::deserialize" }
// )]
// pub parameters: String, #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
// } pub(crate) struct Function {
pub description: String,
pub name: String,
pub parameters: serde_json::Value,
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct Tool { pub(crate) struct Tool {
@ -559,7 +582,7 @@ pub(crate) struct Tool {
#[schema(example = "function")] #[schema(example = "function")]
pub r#type: String, pub r#type: String,
// Grab the tool as generic JSON for debugging purposes. // Grab the tool as generic JSON for debugging purposes.
pub function: serde_json::Value, pub function: Function,
} }
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]

View File

@ -10,6 +10,7 @@ use crate::{
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails, HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse, StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
}; };
use crate::{FunctionRef, Tools};
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode}; use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::sse::{Event, KeepAlive, Sse};
@ -22,6 +23,8 @@ use futures::stream::StreamExt;
use futures::Stream; use futures::Stream;
use futures::TryStreamExt; use futures::TryStreamExt;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use serde_json::Value;
use std::collections::HashMap;
use std::convert::Infallible; use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicBool;
@ -580,45 +583,6 @@ async fn chat_completions(
let logprobs = req.logprobs.unwrap_or(false); let logprobs = req.logprobs.unwrap_or(false);
let seed = req.seed; let seed = req.seed;
// Build a new JSON schema that defines the "$functions" object
// and requires the grammar to choose anyOf the functions defined.
let mut tools = serde_json::json!({});
// First decompose the tools and use the function name as the key
// and the parameters as the value in the "$functions" object.
if let Some(req_tools) = &req.tools {
for tool in req_tools {
let func = tool.function.clone();
let name = func.get("name").unwrap().as_str().unwrap();
let parameters = func.get("parameters").unwrap().as_object().unwrap().clone();
// add a entry to the "$functions" object
tools["$functions"][name] = serde_json::Value::Object(parameters);
}
// now add the properties to the root object
tools["properties"]["function"]["anyOf"] = serde_json::Value::Array(
req.tools
.as_ref()
.unwrap()
.iter()
// map each tool to a $ref to the function
.map(|tool| {
let func = tool.function.clone();
let name = func.get("name").unwrap().as_str().unwrap();
serde_json::json!({
"$ref": format!("#/$functions/{}", name)
})
})
.collect(),
);
}
// only add grammar if tools are present
let grammar = match req.tools {
Some(_grammar) => Some(crate::GrammarType::Json(tools.to_string())),
None => None,
};
// apply chat template to flatten the request into a single input // apply chat template to flatten the request into a single input
let mut inputs = match infer.apply_chat_template(req.messages) { let mut inputs = match infer.apply_chat_template(req.messages) {
Ok(inputs) => inputs, Ok(inputs) => inputs,
@ -635,10 +599,60 @@ async fn chat_completions(
} }
}; };
// append the tools to the inputs with TOOL prompt // if theres a tools object, we need to decompose it and use the function name as the key
// and the parameters as the value in the "$functions" object.
let grammar = if let Some(req_tools) = &req.tools {
let functions: HashMap<String, Value> = {
let mut tools = HashMap::new();
for tool in req_tools {
let func = tool.function.clone();
let name = func.name;
let parameters = match func.parameters.as_object() {
Some(parameters) => parameters.clone(),
None => {
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Input validation error".to_string(),
error_type: "Input validation error".to_string(),
}),
))
}
};
tools.insert(name, Value::Object(parameters));
}
tools
};
let tools = Tools {
function: functions,
any_of: req_tools
.iter()
.map(|tool| FunctionRef::new(&tool.function.name))
.collect(),
};
// update the input
let tools_str = serde_json::to_string(&tools).map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: e.to_string(),
error_type: "Input validation error".to_string(),
}),
)
})?;
let tool_prompt = let tool_prompt =
"Based on the conversation, please choose the most appropriate tool to use:".to_string(); "Based on the conversation, please choose the most appropriate tool to use:"
inputs = format!("{inputs}\n\n{tool_prompt}\n\n{tools}\n\n"); .to_string();
inputs = format!("{inputs}\n\n{tool_prompt}\n\n{tools_str}\n\n");
Some(GrammarType::Json(tools.into()))
} else {
None
};
// build the request passing some parameters // build the request passing some parameters
let generate_request = GenerateRequest { let generate_request = GenerateRequest {