feat: basic tool support via grammar composition

This commit is contained in:
drbh 2024-02-16 16:00:59 +00:00
parent ac5a1c6f51
commit 0f500f6d14
3 changed files with 89 additions and 2 deletions

View File

@ -520,6 +520,46 @@ pub(crate) struct ChatRequest {
#[serde(default)] #[serde(default)]
#[schema(nullable = true, example = 0.95)] #[schema(nullable = true, example = 0.95)]
pub top_p: Option<f32>, pub top_p: Option<f32>,
/// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
/// functions the model may generate JSON inputs for.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub tools: Option<Vec<Tool>>,
}
// TODO: define and use better types for tools
// #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
// enum ToolType {
// #[serde(rename = "function")]
// Function,
// }
// impl Default for ToolType {
// fn default() -> Self {
// ToolType::Function
// }
// }
// #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
// pub(crate) struct Function {
// pub description: String,
// pub name: String,
// #[serde(
// rename = "json",
// deserialize_with = "json_object_or_string_to_string::deserialize"
// )]
// pub parameters: String,
// }
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct Tool {
// The type of the tool. Currently, only 'function' is supported.
#[schema(example = "function")]
pub r#type: String,
// Grab the tool as generic JSON for debugging purposes.
pub function: serde_json::Value,
} }
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]

View File

@ -580,8 +580,47 @@ async fn chat_completions(
let logprobs = req.logprobs.unwrap_or(false); let logprobs = req.logprobs.unwrap_or(false);
let seed = req.seed; let seed = req.seed;
// Build a new JSON schema that defines the "$functions" object
// and requires the grammar to choose anyOf the functions defined.
let mut tools = serde_json::json!({});
// First decompose the tools and use the function name as the key
// and the parameters as the value in the "$functions" object.
if let Some(req_tools) = &req.tools {
for tool in req_tools {
let func = tool.function.clone();
let name = func.get("name").unwrap().as_str().unwrap();
let parameters = func.get("parameters").unwrap().as_object().unwrap().clone();
// add a entry to the "$functions" object
tools["$functions"][name] = serde_json::Value::Object(parameters);
}
// now add the properties to the root object
tools["properties"]["function"]["anyOf"] = serde_json::Value::Array(
req.tools
.as_ref()
.unwrap()
.iter()
// map each tool to a $ref to the function
.map(|tool| {
let func = tool.function.clone();
let name = func.get("name").unwrap().as_str().unwrap();
serde_json::json!({
"$ref": format!("#/$functions/{}", name)
})
})
.collect(),
);
}
// only add grammar if tools are present
let grammar = match req.tools {
Some(_grammar) => Some(crate::GrammarType::Json(tools.to_string())),
None => None,
};
// apply chat template to flatten the request into a single input // apply chat template to flatten the request into a single input
let inputs = match infer.apply_chat_template(req.messages) { let mut inputs = match infer.apply_chat_template(req.messages) {
Ok(inputs) => inputs, Ok(inputs) => inputs,
Err(err) => { Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@ -596,6 +635,11 @@ async fn chat_completions(
} }
}; };
// append the tools to the inputs with TOOL prompt
let tool_prompt =
"Based on the conversation, please choose the most appropriate tool to use:".to_string();
inputs = format!("{inputs}\n\n{tool_prompt}\n\n{tools}\n\n");
// build the request passing some parameters // build the request passing some parameters
let generate_request = GenerateRequest { let generate_request = GenerateRequest {
inputs: inputs.to_string(), inputs: inputs.to_string(),
@ -617,7 +661,7 @@ async fn chat_completions(
decoder_input_details: !stream, decoder_input_details: !stream,
seed, seed,
top_n_tokens: None, top_n_tokens: None,
grammar: None, grammar,
}, },
}; };

View File

@ -513,6 +513,9 @@ class GrammarLogitProcessor(LogitsProcessor):
start_time = time.time() start_time = time.time()
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON: if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
schema = build_regex_from_object(schema) schema = build_regex_from_object(schema)
elif grammar_type == GrammarType.OPTIONAL_GRAMMAR_TYPE_REGEX:
# TODO: use a better method to handle optional grammars
schema = f"({build_regex_from_object(schema)})|.*"
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX: elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
pass # schema is already a regex just here for clarity pass # schema is already a regex just here for clarity
fsm = RegexFSM(schema, tokenizer) fsm = RegexFSM(schema, tokenizer)