mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
feat: add concrete tool types
This commit is contained in:
parent
1aa2126206
commit
014d3fd4ef
@ -528,30 +528,53 @@ pub(crate) struct ChatRequest {
|
||||
pub tools: Option<Vec<Tool>>,
|
||||
}
|
||||
|
||||
// TODO: define and use better types for tools
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
|
||||
pub struct Tools {
|
||||
// rename to "$function" to avoid conflicts with other fields
|
||||
#[serde(rename = "$function")]
|
||||
pub function: std::collections::HashMap<String, serde_json::Value>,
|
||||
pub any_of: Vec<FunctionRef>,
|
||||
}
|
||||
|
||||
// #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||
// enum ToolType {
|
||||
// #[serde(rename = "function")]
|
||||
// Function,
|
||||
// }
|
||||
// add traut to convert to serde_json::Value for tools
|
||||
impl From<Tools> for serde_json::Value {
|
||||
fn from(tools: Tools) -> Self {
|
||||
println!("tools: {:?}", tools);
|
||||
let mut map = serde_json::Map::new();
|
||||
let mut functions = serde_json::Map::new();
|
||||
for (name, value) in tools.function {
|
||||
functions.insert(name, value);
|
||||
}
|
||||
map.insert("$functions".to_string(), serde_json::json!(functions));
|
||||
let mut properties = serde_json::Map::new();
|
||||
let mut function = serde_json::Map::new();
|
||||
function.insert("anyOf".to_string(), serde_json::json!(tools.any_of));
|
||||
properties.insert("function".to_string(), serde_json::json!(function));
|
||||
map.insert("properties".to_string(), serde_json::json!(properties));
|
||||
serde_json::Value::Object(map)
|
||||
}
|
||||
}
|
||||
|
||||
// impl Default for ToolType {
|
||||
// fn default() -> Self {
|
||||
// ToolType::Function
|
||||
// }
|
||||
// }
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||
pub struct FunctionRef {
|
||||
#[serde(rename = "$ref")]
|
||||
pub _ref: String,
|
||||
}
|
||||
|
||||
// #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||
// pub(crate) struct Function {
|
||||
// pub description: String,
|
||||
// pub name: String,
|
||||
// #[serde(
|
||||
// rename = "json",
|
||||
// deserialize_with = "json_object_or_string_to_string::deserialize"
|
||||
// )]
|
||||
// pub parameters: String,
|
||||
// }
|
||||
impl FunctionRef {
|
||||
pub fn new(name: &str) -> Self {
|
||||
Self {
|
||||
_ref: format!("#/$functions/{}", name),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||
pub(crate) struct Function {
|
||||
pub description: String,
|
||||
pub name: String,
|
||||
pub parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||
pub(crate) struct Tool {
|
||||
@ -559,7 +582,7 @@ pub(crate) struct Tool {
|
||||
#[schema(example = "function")]
|
||||
pub r#type: String,
|
||||
// Grab the tool as generic JSON for debugging purposes.
|
||||
pub function: serde_json::Value,
|
||||
pub function: Function,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
|
@ -10,6 +10,7 @@ use crate::{
|
||||
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
|
||||
StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
|
||||
};
|
||||
use crate::{FunctionRef, Tools};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
@ -22,6 +23,8 @@ use futures::stream::StreamExt;
|
||||
use futures::Stream;
|
||||
use futures::TryStreamExt;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
@ -580,45 +583,6 @@ async fn chat_completions(
|
||||
let logprobs = req.logprobs.unwrap_or(false);
|
||||
let seed = req.seed;
|
||||
|
||||
// Build a new JSON schema that defines the "$functions" object
|
||||
// and requires the grammar to choose anyOf the functions defined.
|
||||
let mut tools = serde_json::json!({});
|
||||
|
||||
// First decompose the tools and use the function name as the key
|
||||
// and the parameters as the value in the "$functions" object.
|
||||
if let Some(req_tools) = &req.tools {
|
||||
for tool in req_tools {
|
||||
let func = tool.function.clone();
|
||||
let name = func.get("name").unwrap().as_str().unwrap();
|
||||
let parameters = func.get("parameters").unwrap().as_object().unwrap().clone();
|
||||
// add a entry to the "$functions" object
|
||||
tools["$functions"][name] = serde_json::Value::Object(parameters);
|
||||
}
|
||||
|
||||
// now add the properties to the root object
|
||||
tools["properties"]["function"]["anyOf"] = serde_json::Value::Array(
|
||||
req.tools
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.iter()
|
||||
// map each tool to a $ref to the function
|
||||
.map(|tool| {
|
||||
let func = tool.function.clone();
|
||||
let name = func.get("name").unwrap().as_str().unwrap();
|
||||
serde_json::json!({
|
||||
"$ref": format!("#/$functions/{}", name)
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
|
||||
// only add grammar if tools are present
|
||||
let grammar = match req.tools {
|
||||
Some(_grammar) => Some(crate::GrammarType::Json(tools.to_string())),
|
||||
None => None,
|
||||
};
|
||||
|
||||
// apply chat template to flatten the request into a single input
|
||||
let mut inputs = match infer.apply_chat_template(req.messages) {
|
||||
Ok(inputs) => inputs,
|
||||
@ -635,10 +599,60 @@ async fn chat_completions(
|
||||
}
|
||||
};
|
||||
|
||||
// append the tools to the inputs with TOOL prompt
|
||||
let tool_prompt =
|
||||
"Based on the conversation, please choose the most appropriate tool to use:".to_string();
|
||||
inputs = format!("{inputs}\n\n{tool_prompt}\n\n{tools}\n\n");
|
||||
// if theres a tools object, we need to decompose it and use the function name as the key
|
||||
// and the parameters as the value in the "$functions" object.
|
||||
let grammar = if let Some(req_tools) = &req.tools {
|
||||
let functions: HashMap<String, Value> = {
|
||||
let mut tools = HashMap::new();
|
||||
for tool in req_tools {
|
||||
let func = tool.function.clone();
|
||||
let name = func.name;
|
||||
let parameters = match func.parameters.as_object() {
|
||||
Some(parameters) => parameters.clone(),
|
||||
None => {
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: "Input validation error".to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
tools.insert(name, Value::Object(parameters));
|
||||
}
|
||||
tools
|
||||
};
|
||||
|
||||
let tools = Tools {
|
||||
function: functions,
|
||||
any_of: req_tools
|
||||
.iter()
|
||||
.map(|tool| FunctionRef::new(&tool.function.name))
|
||||
.collect(),
|
||||
};
|
||||
|
||||
// update the input
|
||||
let tools_str = serde_json::to_string(&tools).map_err(|e| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: e.to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
|
||||
let tool_prompt =
|
||||
"Based on the conversation, please choose the most appropriate tool to use:"
|
||||
.to_string();
|
||||
inputs = format!("{inputs}\n\n{tool_prompt}\n\n{tools_str}\n\n");
|
||||
|
||||
Some(GrammarType::Json(tools.into()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// build the request passing some parameters
|
||||
let generate_request = GenerateRequest {
|
||||
|
Loading…
Reference in New Issue
Block a user