mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: improve tool serialization
This commit is contained in:
parent
f72155ae46
commit
4a81dd042f
@ -591,44 +591,39 @@ mod deserialize_tool_choice {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
|
||||
#[derive(Debug, Deserialize, Serialize, ToSchema)]
|
||||
pub struct Tools {
|
||||
// rename to "$function" to avoid conflicts with other fields
|
||||
#[serde(rename = "$function")]
|
||||
pub function: std::collections::HashMap<String, serde_json::Value>,
|
||||
pub any_of: Vec<FunctionRef>,
|
||||
#[serde(flatten)]
|
||||
functions_map: FunctionsMap,
|
||||
properties: Properties,
|
||||
}
|
||||
|
||||
// Allows Tools to be converted to a valid JSON schema object
|
||||
impl From<Tools> for serde_json::Value {
|
||||
fn from(tools: Tools) -> Self {
|
||||
let mut map = serde_json::Map::new();
|
||||
let mut functions = serde_json::Map::new();
|
||||
for (name, value) in tools.function {
|
||||
functions.insert(name, value);
|
||||
}
|
||||
map.insert("$functions".to_string(), serde_json::json!(functions));
|
||||
let mut properties = serde_json::Map::new();
|
||||
let mut function = serde_json::Map::new();
|
||||
function.insert("anyOf".to_string(), serde_json::json!(tools.any_of));
|
||||
properties.insert("function".to_string(), serde_json::json!(function));
|
||||
map.insert("properties".to_string(), serde_json::json!(properties));
|
||||
serde_json::Value::Object(map)
|
||||
}
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct FunctionsMap {
|
||||
#[serde(rename = "$functions")]
|
||||
functions: std::collections::HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||
pub struct FunctionRef {
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct FunctionRef {
|
||||
#[serde(rename = "$ref")]
|
||||
pub _ref: String,
|
||||
ref_path: String,
|
||||
}
|
||||
|
||||
impl FunctionRef {
|
||||
pub fn new(name: &str) -> Self {
|
||||
Self {
|
||||
_ref: format!("#/$functions/{}", name),
|
||||
}
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Properties {
|
||||
#[serde(serialize_with = "serialize_function")]
|
||||
function: Vec<FunctionRef>,
|
||||
}
|
||||
|
||||
fn serialize_function<S>(functions: &Vec<FunctionRef>, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
use serde::ser::SerializeStruct;
|
||||
let mut state = serializer.serialize_struct("Function", 1)?;
|
||||
state.serialize_field("anyOf", functions)?;
|
||||
state.end()
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||
|
@ -10,7 +10,7 @@ use crate::{
|
||||
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
|
||||
StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
|
||||
};
|
||||
use crate::{Function, FunctionRef, ToolCall, ToolType, Tools};
|
||||
use crate::{Function, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
@ -629,11 +629,15 @@ async fn chat_completions(
|
||||
.collect();
|
||||
|
||||
let tools = Tools {
|
||||
function: functions,
|
||||
any_of: tools_to_use
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
.iter()
|
||||
.map(|tool| FunctionRef::new(&tool.function.name))
|
||||
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
};
|
||||
|
||||
let tools_str = serde_json::to_string(&tools).map_err(|e| {
|
||||
@ -646,7 +650,7 @@ async fn chat_completions(
|
||||
)
|
||||
})?;
|
||||
inputs = format!("{inputs}{tool_prompt}{tools_str}");
|
||||
Some(GrammarType::Json(tools.into()))
|
||||
Some(GrammarType::Json(serde_json::json!(tools)))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user