mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
feat: basic tool support via grammar composition
This commit is contained in:
parent
ac5a1c6f51
commit
0f500f6d14
@ -520,6 +520,46 @@ pub(crate) struct ChatRequest {
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = 0.95)]
|
||||
pub top_p: Option<f32>,
|
||||
|
||||
/// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
|
||||
/// functions the model may generate JSON inputs for.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub tools: Option<Vec<Tool>>,
|
||||
}
|
||||
|
||||
// TODO: define and use better types for tools
|
||||
|
||||
// #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||
// enum ToolType {
|
||||
// #[serde(rename = "function")]
|
||||
// Function,
|
||||
// }
|
||||
|
||||
// impl Default for ToolType {
|
||||
// fn default() -> Self {
|
||||
// ToolType::Function
|
||||
// }
|
||||
// }
|
||||
|
||||
// #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||
// pub(crate) struct Function {
|
||||
// pub description: String,
|
||||
// pub name: String,
|
||||
// #[serde(
|
||||
// rename = "json",
|
||||
// deserialize_with = "json_object_or_string_to_string::deserialize"
|
||||
// )]
|
||||
// pub parameters: String,
|
||||
// }
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||
pub(crate) struct Tool {
|
||||
// The type of the tool. Currently, only 'function' is supported.
|
||||
#[schema(example = "function")]
|
||||
pub r#type: String,
|
||||
// Grab the tool as generic JSON for debugging purposes.
|
||||
pub function: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
|
@ -580,8 +580,47 @@ async fn chat_completions(
|
||||
let logprobs = req.logprobs.unwrap_or(false);
|
||||
let seed = req.seed;
|
||||
|
||||
// Build a new JSON schema that defines the "$functions" object
|
||||
// and requires the grammar to choose anyOf the functions defined.
|
||||
let mut tools = serde_json::json!({});
|
||||
|
||||
// First decompose the tools and use the function name as the key
|
||||
// and the parameters as the value in the "$functions" object.
|
||||
if let Some(req_tools) = &req.tools {
|
||||
for tool in req_tools {
|
||||
let func = tool.function.clone();
|
||||
let name = func.get("name").unwrap().as_str().unwrap();
|
||||
let parameters = func.get("parameters").unwrap().as_object().unwrap().clone();
|
||||
// add a entry to the "$functions" object
|
||||
tools["$functions"][name] = serde_json::Value::Object(parameters);
|
||||
}
|
||||
|
||||
// now add the properties to the root object
|
||||
tools["properties"]["function"]["anyOf"] = serde_json::Value::Array(
|
||||
req.tools
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.iter()
|
||||
// map each tool to a $ref to the function
|
||||
.map(|tool| {
|
||||
let func = tool.function.clone();
|
||||
let name = func.get("name").unwrap().as_str().unwrap();
|
||||
serde_json::json!({
|
||||
"$ref": format!("#/$functions/{}", name)
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
|
||||
// only add grammar if tools are present
|
||||
let grammar = match req.tools {
|
||||
Some(_grammar) => Some(crate::GrammarType::Json(tools.to_string())),
|
||||
None => None,
|
||||
};
|
||||
|
||||
// apply chat template to flatten the request into a single input
|
||||
let inputs = match infer.apply_chat_template(req.messages) {
|
||||
let mut inputs = match infer.apply_chat_template(req.messages) {
|
||||
Ok(inputs) => inputs,
|
||||
Err(err) => {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
@ -596,6 +635,11 @@ async fn chat_completions(
|
||||
}
|
||||
};
|
||||
|
||||
// append the tools to the inputs with TOOL prompt
|
||||
let tool_prompt =
|
||||
"Based on the conversation, please choose the most appropriate tool to use:".to_string();
|
||||
inputs = format!("{inputs}\n\n{tool_prompt}\n\n{tools}\n\n");
|
||||
|
||||
// build the request passing some parameters
|
||||
let generate_request = GenerateRequest {
|
||||
inputs: inputs.to_string(),
|
||||
@ -617,7 +661,7 @@ async fn chat_completions(
|
||||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: None,
|
||||
grammar: None,
|
||||
grammar,
|
||||
},
|
||||
};
|
||||
|
||||
|
@ -513,6 +513,9 @@ class GrammarLogitProcessor(LogitsProcessor):
|
||||
start_time = time.time()
|
||||
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
|
||||
schema = build_regex_from_object(schema)
|
||||
elif grammar_type == GrammarType.OPTIONAL_GRAMMAR_TYPE_REGEX:
|
||||
# TODO: use a better method to handle optional grammars
|
||||
schema = f"({build_regex_from_object(schema)})|.*"
|
||||
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
|
||||
pass # schema is already a regex just here for clarity
|
||||
fsm = RegexFSM(schema, tokenizer)
|
||||
|
Loading…
Reference in New Issue
Block a user