mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: improve tool serialization
This commit is contained in:
parent
f72155ae46
commit
4a81dd042f
@ -591,44 +591,39 @@ mod deserialize_tool_choice {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
|
#[derive(Debug, Deserialize, Serialize, ToSchema)]
|
||||||
pub struct Tools {
|
pub struct Tools {
|
||||||
// rename to "$function" to avoid conflicts with other fields
|
#[serde(flatten)]
|
||||||
#[serde(rename = "$function")]
|
functions_map: FunctionsMap,
|
||||||
pub function: std::collections::HashMap<String, serde_json::Value>,
|
properties: Properties,
|
||||||
pub any_of: Vec<FunctionRef>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allows Tools to be converted to a valid JSON schema object
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
impl From<Tools> for serde_json::Value {
|
struct FunctionsMap {
|
||||||
fn from(tools: Tools) -> Self {
|
#[serde(rename = "$functions")]
|
||||||
let mut map = serde_json::Map::new();
|
functions: std::collections::HashMap<String, serde_json::Value>,
|
||||||
let mut functions = serde_json::Map::new();
|
|
||||||
for (name, value) in tools.function {
|
|
||||||
functions.insert(name, value);
|
|
||||||
}
|
|
||||||
map.insert("$functions".to_string(), serde_json::json!(functions));
|
|
||||||
let mut properties = serde_json::Map::new();
|
|
||||||
let mut function = serde_json::Map::new();
|
|
||||||
function.insert("anyOf".to_string(), serde_json::json!(tools.any_of));
|
|
||||||
properties.insert("function".to_string(), serde_json::json!(function));
|
|
||||||
map.insert("properties".to_string(), serde_json::json!(properties));
|
|
||||||
serde_json::Value::Object(map)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct FunctionRef {
|
struct FunctionRef {
|
||||||
#[serde(rename = "$ref")]
|
#[serde(rename = "$ref")]
|
||||||
pub _ref: String,
|
ref_path: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FunctionRef {
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub fn new(name: &str) -> Self {
|
struct Properties {
|
||||||
Self {
|
#[serde(serialize_with = "serialize_function")]
|
||||||
_ref: format!("#/$functions/{}", name),
|
function: Vec<FunctionRef>,
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
fn serialize_function<S>(functions: &Vec<FunctionRef>, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer,
|
||||||
|
{
|
||||||
|
use serde::ser::SerializeStruct;
|
||||||
|
let mut state = serializer.serialize_struct("Function", 1)?;
|
||||||
|
state.serialize_field("anyOf", functions)?;
|
||||||
|
state.end()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||||
|
@ -10,7 +10,7 @@ use crate::{
|
|||||||
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
|
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
|
||||||
StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
|
StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
|
||||||
};
|
};
|
||||||
use crate::{Function, FunctionRef, ToolCall, ToolType, Tools};
|
use crate::{Function, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, Method, StatusCode};
|
use axum::http::{HeaderMap, Method, StatusCode};
|
||||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||||
@ -629,11 +629,15 @@ async fn chat_completions(
|
|||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let tools = Tools {
|
let tools = Tools {
|
||||||
function: functions,
|
functions_map: FunctionsMap { functions },
|
||||||
any_of: tools_to_use
|
properties: Properties {
|
||||||
.iter()
|
function: tools_to_use
|
||||||
.map(|tool| FunctionRef::new(&tool.function.name))
|
.iter()
|
||||||
.collect(),
|
.map(|tool| FunctionRef {
|
||||||
|
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let tools_str = serde_json::to_string(&tools).map_err(|e| {
|
let tools_str = serde_json::to_string(&tools).map_err(|e| {
|
||||||
@ -646,7 +650,7 @@ async fn chat_completions(
|
|||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
inputs = format!("{inputs}{tool_prompt}{tools_str}");
|
inputs = format!("{inputs}{tool_prompt}{tools_str}");
|
||||||
Some(GrammarType::Json(tools.into()))
|
Some(GrammarType::Json(serde_json::json!(tools)))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user