mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: add concrete tool types
This commit is contained in:
parent
1aa2126206
commit
014d3fd4ef
@ -528,30 +528,53 @@ pub(crate) struct ChatRequest {
|
|||||||
pub tools: Option<Vec<Tool>>,
|
pub tools: Option<Vec<Tool>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: define and use better types for tools
|
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
|
||||||
|
pub struct Tools {
|
||||||
|
// rename to "$function" to avoid conflicts with other fields
|
||||||
|
#[serde(rename = "$function")]
|
||||||
|
pub function: std::collections::HashMap<String, serde_json::Value>,
|
||||||
|
pub any_of: Vec<FunctionRef>,
|
||||||
|
}
|
||||||
|
|
||||||
// #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
// add traut to convert to serde_json::Value for tools
|
||||||
// enum ToolType {
|
impl From<Tools> for serde_json::Value {
|
||||||
// #[serde(rename = "function")]
|
fn from(tools: Tools) -> Self {
|
||||||
// Function,
|
println!("tools: {:?}", tools);
|
||||||
// }
|
let mut map = serde_json::Map::new();
|
||||||
|
let mut functions = serde_json::Map::new();
|
||||||
|
for (name, value) in tools.function {
|
||||||
|
functions.insert(name, value);
|
||||||
|
}
|
||||||
|
map.insert("$functions".to_string(), serde_json::json!(functions));
|
||||||
|
let mut properties = serde_json::Map::new();
|
||||||
|
let mut function = serde_json::Map::new();
|
||||||
|
function.insert("anyOf".to_string(), serde_json::json!(tools.any_of));
|
||||||
|
properties.insert("function".to_string(), serde_json::json!(function));
|
||||||
|
map.insert("properties".to_string(), serde_json::json!(properties));
|
||||||
|
serde_json::Value::Object(map)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// impl Default for ToolType {
|
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||||
// fn default() -> Self {
|
pub struct FunctionRef {
|
||||||
// ToolType::Function
|
#[serde(rename = "$ref")]
|
||||||
// }
|
pub _ref: String,
|
||||||
// }
|
}
|
||||||
|
|
||||||
// #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
impl FunctionRef {
|
||||||
// pub(crate) struct Function {
|
pub fn new(name: &str) -> Self {
|
||||||
// pub description: String,
|
Self {
|
||||||
// pub name: String,
|
_ref: format!("#/$functions/{}", name),
|
||||||
// #[serde(
|
}
|
||||||
// rename = "json",
|
}
|
||||||
// deserialize_with = "json_object_or_string_to_string::deserialize"
|
}
|
||||||
// )]
|
|
||||||
// pub parameters: String,
|
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||||
// }
|
pub(crate) struct Function {
|
||||||
|
pub description: String,
|
||||||
|
pub name: String,
|
||||||
|
pub parameters: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||||
pub(crate) struct Tool {
|
pub(crate) struct Tool {
|
||||||
@ -559,7 +582,7 @@ pub(crate) struct Tool {
|
|||||||
#[schema(example = "function")]
|
#[schema(example = "function")]
|
||||||
pub r#type: String,
|
pub r#type: String,
|
||||||
// Grab the tool as generic JSON for debugging purposes.
|
// Grab the tool as generic JSON for debugging purposes.
|
||||||
pub function: serde_json::Value,
|
pub function: Function,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Serialize, Deserialize)]
|
#[derive(Clone, Serialize, Deserialize)]
|
||||||
|
@ -10,6 +10,7 @@ use crate::{
|
|||||||
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
|
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
|
||||||
StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
|
StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
|
||||||
};
|
};
|
||||||
|
use crate::{FunctionRef, Tools};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, Method, StatusCode};
|
use axum::http::{HeaderMap, Method, StatusCode};
|
||||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||||
@ -22,6 +23,8 @@ use futures::stream::StreamExt;
|
|||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
use futures::TryStreamExt;
|
use futures::TryStreamExt;
|
||||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::atomic::AtomicBool;
|
use std::sync::atomic::AtomicBool;
|
||||||
@ -580,45 +583,6 @@ async fn chat_completions(
|
|||||||
let logprobs = req.logprobs.unwrap_or(false);
|
let logprobs = req.logprobs.unwrap_or(false);
|
||||||
let seed = req.seed;
|
let seed = req.seed;
|
||||||
|
|
||||||
// Build a new JSON schema that defines the "$functions" object
|
|
||||||
// and requires the grammar to choose anyOf the functions defined.
|
|
||||||
let mut tools = serde_json::json!({});
|
|
||||||
|
|
||||||
// First decompose the tools and use the function name as the key
|
|
||||||
// and the parameters as the value in the "$functions" object.
|
|
||||||
if let Some(req_tools) = &req.tools {
|
|
||||||
for tool in req_tools {
|
|
||||||
let func = tool.function.clone();
|
|
||||||
let name = func.get("name").unwrap().as_str().unwrap();
|
|
||||||
let parameters = func.get("parameters").unwrap().as_object().unwrap().clone();
|
|
||||||
// add a entry to the "$functions" object
|
|
||||||
tools["$functions"][name] = serde_json::Value::Object(parameters);
|
|
||||||
}
|
|
||||||
|
|
||||||
// now add the properties to the root object
|
|
||||||
tools["properties"]["function"]["anyOf"] = serde_json::Value::Array(
|
|
||||||
req.tools
|
|
||||||
.as_ref()
|
|
||||||
.unwrap()
|
|
||||||
.iter()
|
|
||||||
// map each tool to a $ref to the function
|
|
||||||
.map(|tool| {
|
|
||||||
let func = tool.function.clone();
|
|
||||||
let name = func.get("name").unwrap().as_str().unwrap();
|
|
||||||
serde_json::json!({
|
|
||||||
"$ref": format!("#/$functions/{}", name)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// only add grammar if tools are present
|
|
||||||
let grammar = match req.tools {
|
|
||||||
Some(_grammar) => Some(crate::GrammarType::Json(tools.to_string())),
|
|
||||||
None => None,
|
|
||||||
};
|
|
||||||
|
|
||||||
// apply chat template to flatten the request into a single input
|
// apply chat template to flatten the request into a single input
|
||||||
let mut inputs = match infer.apply_chat_template(req.messages) {
|
let mut inputs = match infer.apply_chat_template(req.messages) {
|
||||||
Ok(inputs) => inputs,
|
Ok(inputs) => inputs,
|
||||||
@ -635,10 +599,60 @@ async fn chat_completions(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// append the tools to the inputs with TOOL prompt
|
// if theres a tools object, we need to decompose it and use the function name as the key
|
||||||
|
// and the parameters as the value in the "$functions" object.
|
||||||
|
let grammar = if let Some(req_tools) = &req.tools {
|
||||||
|
let functions: HashMap<String, Value> = {
|
||||||
|
let mut tools = HashMap::new();
|
||||||
|
for tool in req_tools {
|
||||||
|
let func = tool.function.clone();
|
||||||
|
let name = func.name;
|
||||||
|
let parameters = match func.parameters.as_object() {
|
||||||
|
Some(parameters) => parameters.clone(),
|
||||||
|
None => {
|
||||||
|
return Err((
|
||||||
|
StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: "Input validation error".to_string(),
|
||||||
|
error_type: "Input validation error".to_string(),
|
||||||
|
}),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
tools.insert(name, Value::Object(parameters));
|
||||||
|
}
|
||||||
|
tools
|
||||||
|
};
|
||||||
|
|
||||||
|
let tools = Tools {
|
||||||
|
function: functions,
|
||||||
|
any_of: req_tools
|
||||||
|
.iter()
|
||||||
|
.map(|tool| FunctionRef::new(&tool.function.name))
|
||||||
|
.collect(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// update the input
|
||||||
|
let tools_str = serde_json::to_string(&tools).map_err(|e| {
|
||||||
|
(
|
||||||
|
StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: e.to_string(),
|
||||||
|
error_type: "Input validation error".to_string(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
let tool_prompt =
|
let tool_prompt =
|
||||||
"Based on the conversation, please choose the most appropriate tool to use:".to_string();
|
"Based on the conversation, please choose the most appropriate tool to use:"
|
||||||
inputs = format!("{inputs}\n\n{tool_prompt}\n\n{tools}\n\n");
|
.to_string();
|
||||||
|
inputs = format!("{inputs}\n\n{tool_prompt}\n\n{tools_str}\n\n");
|
||||||
|
|
||||||
|
Some(GrammarType::Json(tools.into()))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
// build the request passing some parameters
|
// build the request passing some parameters
|
||||||
let generate_request = GenerateRequest {
|
let generate_request = GenerateRequest {
|
||||||
|
Loading…
Reference in New Issue
Block a user