feat: improve tool serialization

This commit is contained in:
drbh 2024-02-27 17:52:46 +00:00
parent f72155ae46
commit 4a81dd042f
2 changed files with 36 additions and 37 deletions

View File

@ -591,44 +591,39 @@ mod deserialize_tool_choice {
}
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
#[derive(Debug, Deserialize, Serialize, ToSchema)]
pub struct Tools {
// rename to "$function" to avoid conflicts with other fields
#[serde(rename = "$function")]
pub function: std::collections::HashMap<String, serde_json::Value>,
pub any_of: Vec<FunctionRef>,
#[serde(flatten)]
functions_map: FunctionsMap,
properties: Properties,
}
// Allows Tools to be converted to a valid JSON schema object
impl From<Tools> for serde_json::Value {
fn from(tools: Tools) -> Self {
let mut map = serde_json::Map::new();
let mut functions = serde_json::Map::new();
for (name, value) in tools.function {
functions.insert(name, value);
}
map.insert("$functions".to_string(), serde_json::json!(functions));
let mut properties = serde_json::Map::new();
let mut function = serde_json::Map::new();
function.insert("anyOf".to_string(), serde_json::json!(tools.any_of));
properties.insert("function".to_string(), serde_json::json!(function));
map.insert("properties".to_string(), serde_json::json!(properties));
serde_json::Value::Object(map)
}
#[derive(Debug, Serialize, Deserialize)]
struct FunctionsMap {
#[serde(rename = "$functions")]
functions: std::collections::HashMap<String, serde_json::Value>,
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub struct FunctionRef {
#[derive(Debug, Serialize, Deserialize)]
struct FunctionRef {
#[serde(rename = "$ref")]
pub _ref: String,
ref_path: String,
}
impl FunctionRef {
pub fn new(name: &str) -> Self {
Self {
_ref: format!("#/$functions/{}", name),
}
}
#[derive(Debug, Serialize, Deserialize)]
struct Properties {
#[serde(serialize_with = "serialize_function")]
function: Vec<FunctionRef>,
}
fn serialize_function<S>(functions: &Vec<FunctionRef>, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeStruct;
let mut state = serializer.serialize_struct("Function", 1)?;
state.serialize_field("anyOf", functions)?;
state.end()
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]

View File

@ -10,7 +10,7 @@ use crate::{
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
};
use crate::{Function, FunctionRef, ToolCall, ToolType, Tools};
use crate::{Function, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
@ -629,11 +629,15 @@ async fn chat_completions(
.collect();
let tools = Tools {
function: functions,
any_of: tools_to_use
.iter()
.map(|tool| FunctionRef::new(&tool.function.name))
.collect(),
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.collect(),
},
};
let tools_str = serde_json::to_string(&tools).map_err(|e| {
@ -646,7 +650,7 @@ async fn chat_completions(
)
})?;
inputs = format!("{inputs}{tool_prompt}{tools_str}");
Some(GrammarType::Json(tools.into()))
Some(GrammarType::Json(serde_json::json!(tools)))
} else {
None
};