feat: improve tools api and add tool prompt

Author: drbh
Date:   2024-02-24 01:58:54 +00:00
Parent: 0e30e65822
Commit: a32d3dd6cb
5 changed files with 163 additions and 74 deletions

View File

@@ -55,7 +55,6 @@ enum GrammarType {
     GRAMMAR_TYPE_NONE = 0;
     GRAMMAR_TYPE_JSON = 1;
     GRAMMAR_TYPE_REGEX = 2;
-    GRAMMAR_TYPE_OPTIONAL_JSON = 3;
 }
 
 message NextTokenChooserParameters {
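
Note: with the optional-JSON variant gone, tool grammars now ride on the plain JSON grammar type (see `Some(GrammarType::Json(tools.into()))` in the router changes below). A minimal sketch of the remaining mapping, assuming the generated enum values:

    // Hypothetical helper over the generated enum values; only three variants remain.
    fn grammar_name(grammar_type: i32) -> &'static str {
        match grammar_type {
            0 => "GRAMMAR_TYPE_NONE",
            1 => "GRAMMAR_TYPE_JSON",
            2 => "GRAMMAR_TYPE_REGEX",
            _ => "unknown (GRAMMAR_TYPE_OPTIONAL_JSON = 3 was removed)",
        }
    }

    fn main() {
        assert_eq!(grammar_name(1), "GRAMMAR_TYPE_JSON");
    }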

View File

@@ -812,23 +812,27 @@ mod tests {
         messages: vec![
             Message {
                 role: "user".to_string(),
-                content: "Hi!".to_string(),
+                content: Some("Hi!".to_string()),
                 name: None,
+                tool_calls: None,
             },
             Message {
                 role: "assistant".to_string(),
-                content: "Hello how can I help?".to_string(),
+                content: Some("Hello how can I help?".to_string()),
                 name: None,
+                tool_calls: None,
             },
             Message {
                 role: "user".to_string(),
-                content: "What is Deep Learning?".to_string(),
+                content: Some("What is Deep Learning?".to_string()),
                 name: None,
+                tool_calls: None,
             },
             Message {
                 role: "assistant".to_string(),
-                content: "magic!".to_string(),
+                content: Some("magic!".to_string()),
                 name: None,
+                tool_calls: None,
             },
         ],
         bos_token: Some("[BOS]"),
@@ -877,28 +881,33 @@ mod tests {
         messages: vec![
             Message {
                 role: "user".to_string(),
-                content: "Hi!".to_string(),
+                content: Some("Hi!".to_string()),
                 name: None,
+                tool_calls: None,
             },
             Message {
                 role: "user".to_string(),
-                content: "Hi again!".to_string(),
+                content: Some("Hi again!".to_string()),
                 name: None,
+                tool_calls: None,
             },
             Message {
                 role: "assistant".to_string(),
-                content: "Hello how can I help?".to_string(),
+                content: Some("Hello how can I help?".to_string()),
                 name: None,
+                tool_calls: None,
             },
             Message {
                 role: "user".to_string(),
-                content: "What is Deep Learning?".to_string(),
+                content: Some("What is Deep Learning?".to_string()),
                 name: None,
+                tool_calls: None,
             },
             Message {
                 role: "assistant".to_string(),
-                content: "magic!".to_string(),
+                content: Some("magic!".to_string()),
                 name: None,
+                tool_calls: None,
             },
         ],
         bos_token: Some("[BOS]"),
@@ -952,23 +961,27 @@ mod tests {
         messages: vec![
             Message {
                 role: "user".to_string(),
-                content: "Hi!".to_string(),
+                content: Some("Hi!".to_string()),
                 name: None,
+                tool_calls: None,
             },
             Message {
                 role: "assistant".to_string(),
-                content: "Hello how can I help?".to_string(),
+                content: Some("Hello how can I help?".to_string()),
                 name: None,
+                tool_calls: None,
             },
             Message {
                 role: "user".to_string(),
-                content: "What is Deep Learning?".to_string(),
+                content: Some("What is Deep Learning?".to_string()),
                 name: None,
+                tool_calls: None,
             },
             Message {
                 role: "assistant".to_string(),
-                content: "magic!".to_string(),
+                content: Some("magic!".to_string()),
                 name: None,
+                tool_calls: None,
             },
         ],
         bos_token: Some("[BOS]"),
@@ -1006,23 +1019,27 @@ mod tests {
         messages: vec![
             Message {
                 role: "user".to_string(),
-                content: "Hi!".to_string(),
+                content: Some("Hi!".to_string()),
                 name: None,
+                tool_calls: None,
             },
             Message {
                 role: "assistant".to_string(),
-                content: "Hello how can I help?".to_string(),
+                content: Some("Hello how can I help?".to_string()),
                 name: None,
+                tool_calls: None,
             },
             Message {
                 role: "user".to_string(),
-                content: "What is Deep Learning?".to_string(),
+                content: Some("What is Deep Learning?".to_string()),
                 name: None,
+                tool_calls: None,
             },
             Message {
                 role: "assistant".to_string(),
-                content: "magic!".to_string(),
+                content: Some("magic!".to_string()),
                 name: None,
+                tool_calls: None,
             },
         ],
         bos_token: Some("[BOS]"),

View File

@@ -358,10 +358,11 @@ impl ChatCompletion {
     pub(crate) fn new(
         model: String,
         system_fingerprint: String,
-        output: String,
+        output: Option<String>,
         created: u64,
         details: Details,
         return_logprobs: bool,
+        tool_calls: Option<ToolCall>,
     ) -> Self {
         Self {
             id: String::new(),
@@ -375,6 +376,7 @@ impl ChatCompletion {
                 role: "assistant".into(),
                 content: output,
                 name: None,
+                tool_calls,
             },
             logprobs: return_logprobs
                 .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
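
Note: with `content: output` now optional and `tool_calls` threaded into the message, the completion's `message` serializes with whichever side is populated. Also worth noting: `tool_calls` is a single object at this point, not an array as in the OpenAI schema. An illustrative tool-call response shape (values mirror the handler code further down, not captured output):

    use serde_json::json;

    fn main() {
        // Tool-call path: content is null and the call object is attached instead.
        // `id: 0`, `type: "function"`, and `name: "tools"` follow the handler below.
        let message = json!({
            "role": "assistant",
            "content": null,
            "tool_calls": {
                "id": 0,
                "type": "function",
                "function": {"name": "tools", "parameters": {"location": "Brooklyn"}}
            }
        });
        println!("{message}");
    }
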
@@ -527,10 +529,61 @@ pub(crate) struct ChatRequest {
     #[schema(nullable = true, example = "null")]
     pub tools: Option<Vec<Tool>>,
 
+    /// A prompt to be appended before the tools
+    #[serde(default = "default_tool_prompt")]
+    pub tool_prompt: Option<String>,
+
     /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
     #[serde(default)]
     #[schema(nullable = true, example = "null")]
-    pub tool_choice: Option<String>,
+    #[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
+    pub tool_choice: Option<ToolType>,
 }
 
+fn default_tool_prompt() -> Option<String> {
+    Some(
+        "\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(),
+    )
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+enum ToolType {
+    FunctionName(String),
+    OneOf,
+}
+
+/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None)
+mod deserialize_tool_choice {
+    use super::*;
+    use serde::de;
+    use serde::Deserializer;
+    use serde_json::Value;
+
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ToolType>, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let value = Value::deserialize(deserializer)?;
+
+        match value {
+            Value::String(s) => match s.as_str() {
+                "none" => Ok(None),
+                "auto" => Ok(Some(ToolType::OneOf)),
+                _ => Ok(Some(ToolType::FunctionName(s))),
+            },
+            Value::Object(map) => {
+                if let Some(content) = map
+                    .get("function")
+                    .and_then(|v| v.get("name"))
+                    .and_then(|v| v.as_str())
+                {
+                    Ok(Some(ToolType::FunctionName(content.to_string())))
+                } else {
+                    Err(de::Error::custom("function key not found in tool choice"))
+                }
+            }
+            _ => Err(de::Error::custom("invalid token format")),
+        }
+    }
+}
+
 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
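
Note: the custom deserializer accepts both the OpenAI-style string form and the object form of `tool_choice`. A standalone re-statement of the mapping for illustration (a local function, not the router's code):

    use serde_json::{json, Value};

    /// Mirrors the mapping implemented by `deserialize_tool_choice` above.
    fn map_tool_choice(value: &Value) -> Result<Option<String>, String> {
        match value {
            Value::String(s) => match s.as_str() {
                "none" => Ok(None),                      // tools disabled
                "auto" => Ok(Some("OneOf".to_string())), // model may pick any tool
                other => Ok(Some(format!("FunctionName({other})"))),
            },
            Value::Object(map) => map
                .get("function")
                .and_then(|v| v.get("name"))
                .and_then(|v| v.as_str())
                .map(|name| Some(format!("FunctionName({name})")))
                .ok_or_else(|| "function key not found in tool choice".to_string()),
            _ => Err("invalid token format".to_string()),
        }
    }

    fn main() {
        // OpenAI-style object form selects a single function by name.
        let choice = json!({"type": "function", "function": {"name": "get_weather"}});
        assert_eq!(
            map_tool_choice(&choice).unwrap(),
            Some("FunctionName(get_weather)".to_string())
        );
        assert_eq!(map_tool_choice(&json!("none")).unwrap(), None);
        assert_eq!(map_tool_choice(&json!("auto")).unwrap(), Some("OneOf".to_string()));
    }
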
@ -575,7 +628,8 @@ impl FunctionRef {
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct Function {
pub description: String,
#[serde(default)]
pub description: Option<String>,
pub name: String,
pub parameters: serde_json::Value,
}
@@ -597,15 +651,24 @@ pub(crate) struct ChatTemplateInputs<'a> {
     add_generation_prompt: bool,
 }
 
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct ToolCall {
+    pub id: u32,
+    pub r#type: String,
+    pub function: Function,
+}
+
 #[derive(Clone, Deserialize, ToSchema, Serialize)]
 pub(crate) struct Message {
     #[schema(example = "user")]
     pub role: String,
     #[schema(example = "My name is David and I")]
-    pub content: String,
+    pub content: Option<String>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "\"David\"")]
     pub name: Option<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<ToolCall>,
 }
 
 #[derive(Clone, Debug, Deserialize, ToSchema)]
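
Note: since `content` is now optional and a message can carry a tool call, an assistant turn that invoked a tool round-trips with `content: null`. A minimal serde sketch using trimmed local copies of these structs (utoipa annotations omitted):

    use serde::{Deserialize, Serialize};
    use serde_json::json;

    // Trimmed local copies of the router's types, for illustration only.
    #[derive(Serialize, Deserialize)]
    struct Function {
        #[serde(default)]
        description: Option<String>,
        name: String,
        parameters: serde_json::Value,
    }

    #[derive(Serialize, Deserialize)]
    struct ToolCall {
        id: u32,
        r#type: String,
        function: Function,
    }

    #[derive(Serialize, Deserialize)]
    struct Message {
        role: String,
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        tool_calls: Option<ToolCall>,
    }

    fn main() {
        // An assistant turn that answered with a tool call instead of text.
        let raw = json!({
            "role": "assistant",
            "content": null,
            "tool_calls": {
                "id": 0,
                "type": "function",
                "function": {
                    "name": "tools",
                    "parameters": {"format": "celsius", "location": "Brooklyn"}
                }
            }
        });
        let msg: Message = serde_json::from_value(raw).unwrap();
        assert!(msg.content.is_none());
        assert!(msg.tool_calls.is_some());
    }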

View File

@@ -10,7 +10,7 @@ use crate::{
     HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
     StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
 };
-use crate::{FunctionRef, Tools};
+use crate::{Function, FunctionRef, ToolCall, ToolType, Tools};
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
 use axum::response::sse::{Event, KeepAlive, Sse};
@@ -583,6 +583,16 @@ async fn chat_completions(
     let logprobs = req.logprobs.unwrap_or(false);
     let seed = req.seed;
 
+    if stream && req.tools.is_some() {
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Tools are not supported with stream".to_string(),
+                error_type: "Input validation error".to_string(),
+            }),
+        ));
+    }
+
     // apply chat template to flatten the request into a single input
     let mut inputs = match infer.apply_chat_template(req.messages) {
         Ok(inputs) => inputs,
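
Note: streaming and tools are mutually exclusive for now, so this guard fails fast with 422 Unprocessable Entity before any template work. An example of a request body that now gets rejected (illustrative values):

    use serde_json::json;

    fn main() {
        // `stream: true` combined with a non-null `tools` value is rejected with:
        // {"error": "Tools are not supported with stream", "error_type": "Input validation error"}
        let request = json!({
            "model": "tgi",
            "stream": true,
            "tools": [],
            "messages": [{"role": "user", "content": "What is the weather?"}]
        });
        println!("{request}");
    }
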
@@ -599,53 +609,35 @@ async fn chat_completions(
         }
     };
 
-    // if theres a tools object, we need to decompose it and use the function name as the key
-    // and the parameters as the value in the "$functions" object.
-    let grammar = if let Some(ref req_tools) = &req.tools {
-        // get the tool_choice if there is one
-        let tool_choice = &req.tool_choice;
-
-        let tools_to_use = if let Some(tool_choice) = tool_choice {
-            // get the tool based on the tool_choice
-            let tool = req_tools
-                .iter()
-                .find(|tool| tool.function.name == *tool_choice)
-                .ok_or_else(|| {
-                    (
-                        StatusCode::UNPROCESSABLE_ENTITY,
-                        Json(ErrorResponse {
-                            error: "Input validation error".to_string(),
-                            error_type: "Input validation error".to_string(),
-                        }),
-                    )
-                })?;
-            vec![tool.clone()]
-        } else {
-            req_tools.clone()
-        };
-
-        let functions: HashMap<String, Value> = {
-            let mut tools = HashMap::new();
-            for tool in &tools_to_use {
-                let func = tool.function.clone();
-                let name = func.name;
-                let parameters = match func.parameters.as_object() {
-                    Some(parameters) => parameters.clone(),
-                    None => {
-                        return Err((
-                            StatusCode::UNPROCESSABLE_ENTITY,
-                            Json(ErrorResponse {
-                                error: "Input validation error".to_string(),
-                                error_type: "Input validation error".to_string(),
-                            }),
-                        ))
-                    }
-                };
-                tools.insert(name, Value::Object(parameters));
-            }
-            tools
-        };
+    let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
+        let tool_prompt = req.tool_prompt.unwrap_or_default();
+        let tools_to_use = match tool_choice {
+            ToolType::FunctionName(name) => {
+                vec![req_tools
+                    .iter()
+                    .find(|tool| tool.function.name == *name)
+                    .ok_or_else(|| {
+                        (
+                            StatusCode::UNPROCESSABLE_ENTITY,
+                            Json(ErrorResponse {
+                                error: "Input validation error".to_string(),
+                                error_type: "Input validation error".to_string(),
+                            }),
+                        )
+                    })?
+                    .clone()]
+            }
+            ToolType::OneOf => req_tools.to_owned(),
+        };
+
+        let functions: HashMap<String, Value> = tools_to_use
+            .iter()
+            .map(|tool| {
+                let func = tool.function.clone();
+                (func.name, func.parameters)
+            })
+            .collect();
 
         let tools = Tools {
             function: functions,
             any_of: tools_to_use
@@ -654,7 +646,6 @@ async fn chat_completions(
                 .collect(),
         };
 
-        // update the input
         let tools_str = serde_json::to_string(&tools).map_err(|e| {
             (
                 StatusCode::UNPROCESSABLE_ENTITY,
@@ -664,12 +655,7 @@ async fn chat_completions(
                 }),
             )
         })?;
-        let tool_prompt =
-            "Based on the conversation, please choose the most appropriate tool to use:"
-                .to_string();
-        inputs = format!("{inputs}\n\n{tool_prompt}\n\n{tools_str}\n\n");
+        inputs = format!("{inputs}{tool_prompt}{tools_str}");
         Some(GrammarType::Json(tools.into()))
     } else {
         None
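
Note: the `\n\n` separators in the old `format!` are dropped because the default `tool_prompt` now carries its own leading newline and trailing space. A rough sketch of the final input assembly (the `tools_str` payload is an abbreviated assumption, not the exact `Tools` serialization):

    fn main() {
        // Mirrors `inputs = format!("{inputs}{tool_prompt}{tools_str}")` above.
        let inputs = "User: What is the weather in Brooklyn?".to_string();
        // Default prompt already starts with a newline, hence no extra separators.
        let tool_prompt =
            "\nBased on the conversation, please choose the most appropriate tool to use: ";
        // Abbreviated stand-in for serde_json::to_string(&tools)?.
        let tools_str = r#"{"$functions":{"get_weather":{"type":"object"}}}"#;
        let prompt = format!("{inputs}{tool_prompt}{tools_str}");
        println!("{prompt}");
    }
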
@@ -696,7 +682,7 @@ async fn chat_completions(
             decoder_input_details: !stream,
             seed,
             top_n_tokens: None,
-            grammar,
+            grammar: tool_grammar.clone(),
         },
     };
@@ -760,14 +746,41 @@ async fn chat_completions(
         .unwrap_or_else(|_| std::time::Duration::from_secs(0))
         .as_secs();
 
+    let (tool_calls, output) = if tool_grammar.is_some() {
+        // gen_text should be valid json
+        let gen_text_value: Value =
+            serde_json::from_str(&generation.generated_text).map_err(|e| {
+                (
+                    StatusCode::UNPROCESSABLE_ENTITY,
+                    Json(ErrorResponse {
+                        error: e.to_string(),
+                        error_type: "Input validation error".to_string(),
+                    }),
+                )
+            })?;
+
+        let tool_call = Some(ToolCall {
+            id: 0,
+            r#type: "function".to_string(),
+            function: Function {
+                description: None,
+                name: "tools".to_string(),
+                parameters: gen_text_value.get("function").unwrap().clone(),
+            },
+        });
+        (tool_call, None)
+    } else {
+        (None, Some(generation.generated_text))
+    };
+
     // build the complete response object with the full text
     let response = ChatCompletion::new(
         model_id,
         system_fingerprint,
-        generation.generated_text,
+        output,
         current_time,
         generation.details.unwrap(),
         logprobs,
+        tool_calls,
     );
 
     // wrap generation inside a Vec to match api-inference
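
Note: under the JSON grammar the model is expected to emit an object with a top-level `function` key; the handler lifts that object into `ToolCall.function.parameters` (via an `unwrap`, so a schema miss would panic rather than return a 422). A small sketch of the extraction (the sample payload is illustrative):

    use serde_json::Value;

    fn main() {
        // What the grammar-constrained generated_text is expected to look like.
        let generated_text = r#"{"function": {"format": "celsius", "location": "Brooklyn"}}"#;
        let gen_text_value: Value = serde_json::from_str(generated_text).unwrap();

        // As in the handler: everything under "function" becomes the call's parameters.
        let parameters = gen_text_value.get("function").cloned().unwrap();
        assert_eq!(parameters["location"], "Brooklyn");
    }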

View File

@@ -513,9 +513,6 @@ class GrammarLogitProcessor(LogitsProcessor):
         start_time = time.time()
         if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
             schema = build_regex_from_object(schema)
-        elif grammar_type == GrammarType.GRAMMAR_TYPE_OPTIONAL_JSON:
-            # TODO: use a better method to handle optional grammars
-            schema = f"({build_regex_from_object(schema)})|.*"
         elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
             pass  # schema is already a regex just here for clarity
         fsm = RegexFSM(schema, tokenizer)