feat: improve tools api and add tool prompt

commit a32d3dd6cb
parent 0e30e65822
Author: drbh
Date:   2024-02-24 01:58:54 +00:00
5 changed files with 163 additions and 74 deletions

proto/generate.proto

@@ -55,7 +55,6 @@ enum GrammarType {
     GRAMMAR_TYPE_NONE = 0;
     GRAMMAR_TYPE_JSON = 1;
     GRAMMAR_TYPE_REGEX = 2;
-    GRAMMAR_TYPE_OPTIONAL_JSON = 3;
 }

 message NextTokenChooserParameters {

router/src/infer.rs

@@ -812,23 +812,27 @@ mod tests {
             messages: vec![
                 Message {
                     role: "user".to_string(),
-                    content: "Hi!".to_string(),
+                    content: Some("Hi!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "Hello how can I help?".to_string(),
+                    content: Some("Hello how can I help?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "user".to_string(),
-                    content: "What is Deep Learning?".to_string(),
+                    content: Some("What is Deep Learning?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "magic!".to_string(),
+                    content: Some("magic!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
             ],
             bos_token: Some("[BOS]"),
@@ -877,28 +881,33 @@ mod tests {
             messages: vec![
                 Message {
                     role: "user".to_string(),
-                    content: "Hi!".to_string(),
+                    content: Some("Hi!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "user".to_string(),
-                    content: "Hi again!".to_string(),
+                    content: Some("Hi again!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "Hello how can I help?".to_string(),
+                    content: Some("Hello how can I help?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "user".to_string(),
-                    content: "What is Deep Learning?".to_string(),
+                    content: Some("What is Deep Learning?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "magic!".to_string(),
+                    content: Some("magic!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
             ],
             bos_token: Some("[BOS]"),
@@ -952,23 +961,27 @@ mod tests {
             messages: vec![
                 Message {
                     role: "user".to_string(),
-                    content: "Hi!".to_string(),
+                    content: Some("Hi!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "Hello how can I help?".to_string(),
+                    content: Some("Hello how can I help?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "user".to_string(),
-                    content: "What is Deep Learning?".to_string(),
+                    content: Some("What is Deep Learning?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "magic!".to_string(),
+                    content: Some("magic!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
             ],
             bos_token: Some("[BOS]"),
@@ -1006,23 +1019,27 @@ mod tests {
             messages: vec![
                 Message {
                     role: "user".to_string(),
-                    content: "Hi!".to_string(),
+                    content: Some("Hi!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "Hello how can I help?".to_string(),
+                    content: Some("Hello how can I help?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "user".to_string(),
-                    content: "What is Deep Learning?".to_string(),
+                    content: Some("What is Deep Learning?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "magic!".to_string(),
+                    content: Some("magic!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
             ],
             bos_token: Some("[BOS]"),
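
Note: `content` becomes `Option<String>` and messages gain an optional `tool_calls` field, so an assistant turn can carry a tool call instead of text. A minimal construction sketch (the weather payload is illustrative, not taken from the test suite):

    let msg = Message {
        role: "assistant".to_string(),
        content: None,
        name: None,
        tool_calls: Some(ToolCall {
            id: 0,
            r#type: "function".to_string(),
            function: Function {
                description: None,
                name: "get_weather".to_string(),
                parameters: serde_json::json!({ "location": "Paris" }),
            },
        }),
    };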

router/src/lib.rs

@@ -358,10 +358,11 @@ impl ChatCompletion {
     pub(crate) fn new(
         model: String,
         system_fingerprint: String,
-        output: String,
+        output: Option<String>,
         created: u64,
         details: Details,
         return_logprobs: bool,
+        tool_calls: Option<ToolCall>,
     ) -> Self {
         Self {
             id: String::new(),
@@ -375,6 +376,7 @@ impl ChatCompletion {
                 role: "assistant".into(),
                 content: output,
                 name: None,
+                tool_calls,
             },
             logprobs: return_logprobs
                 .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
@@ -527,10 +529,61 @@ pub(crate) struct ChatRequest {
     #[schema(nullable = true, example = "null")]
     pub tools: Option<Vec<Tool>>,
+    /// A prompt to be appended before the tools
+    #[serde(default = "default_tool_prompt")]
+    pub tool_prompt: Option<String>,
     /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
     #[serde(default)]
     #[schema(nullable = true, example = "null")]
-    pub tool_choice: Option<String>,
+    #[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
+    pub tool_choice: Option<ToolType>,
+}
+
+fn default_tool_prompt() -> Option<String> {
+    Some(
+        "\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(),
+    )
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+enum ToolType {
+    FunctionName(String),
+    OneOf,
+}
+
+/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None)
+mod deserialize_tool_choice {
+    use super::*;
+    use serde::de;
+    use serde::Deserializer;
+    use serde_json::Value;
+
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ToolType>, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let value = Value::deserialize(deserializer)?;
+
+        match value {
+            Value::String(s) => match s.as_str() {
+                "none" => Ok(None),
+                "auto" => Ok(Some(ToolType::OneOf)),
+                _ => Ok(Some(ToolType::FunctionName(s))),
+            },
+            Value::Object(map) => {
+                if let Some(content) = map
+                    .get("function")
+                    .and_then(|v| v.get("name"))
+                    .and_then(|v| v.as_str())
+                {
+                    Ok(Some(ToolType::FunctionName(content.to_string())))
+                } else {
+                    Err(de::Error::custom("function key not found in tool choice"))
+                }
+            }
+            _ => Err(de::Error::custom("invalid token format")),
+        }
+    }
 }

 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
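
Note: the custom deserializer accepts both the bare-string and the OpenAI-style object form of `tool_choice`. A sketch of the mapping (`Probe` is a hypothetical wrapper used only for illustration):

    #[derive(serde::Deserialize)]
    struct Probe {
        #[serde(default, deserialize_with = "deserialize_tool_choice::deserialize")]
        tool_choice: Option<ToolType>,
    }

    // "auto" means "pick any of the provided tools"; "none" disables tool use.
    let p: Probe = serde_json::from_str(r#"{"tool_choice": "auto"}"#).unwrap();
    assert!(matches!(p.tool_choice, Some(ToolType::OneOf)));

    let p: Probe = serde_json::from_str(r#"{"tool_choice": "none"}"#).unwrap();
    assert!(p.tool_choice.is_none());

    // A bare name or {"function": {"name": ...}} pins one specific function.
    let p: Probe =
        serde_json::from_str(r#"{"tool_choice": {"function": {"name": "get_weather"}}}"#).unwrap();
    assert!(matches!(p.tool_choice, Some(ToolType::FunctionName(n)) if n == "get_weather"));
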
@@ -575,7 +628,8 @@ impl FunctionRef {
 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
 pub(crate) struct Function {
-    pub description: String,
+    #[serde(default)]
+    pub description: Option<String>,
     pub name: String,
     pub parameters: serde_json::Value,
 }
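
Note: with `#[serde(default)]` on the now-optional `description`, tool definitions whose functions omit a description deserialize cleanly instead of failing the request.
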
@@ -597,15 +651,24 @@ pub(crate) struct ChatTemplateInputs<'a> {
     add_generation_prompt: bool,
 }

+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct ToolCall {
+    pub id: u32,
+    pub r#type: String,
+    pub function: Function,
+}
+
 #[derive(Clone, Deserialize, ToSchema, Serialize)]
 pub(crate) struct Message {
     #[schema(example = "user")]
     pub role: String,
     #[schema(example = "My name is David and I")]
-    pub content: String,
+    pub content: Option<String>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "\"David\"")]
     pub name: Option<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<ToolCall>,
 }

 #[derive(Clone, Debug, Deserialize, ToSchema)]
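
Note: `r#type` serializes as plain `type`, and `tool_calls` is a single object at this stage (not an array). Given the derives above, an assistant message carrying a tool call comes out on the wire roughly as this sketch (values illustrative; `content` is null because only the tool call is returned):

    let wire = serde_json::json!({
        "role": "assistant",
        "content": null,
        "tool_calls": {
            "id": 0,
            "type": "function",
            "function": {
                "description": null,
                "name": "tools",
                "parameters": { "location": "Paris" }
            }
        }
    });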

router/src/server.rs

@@ -10,7 +10,7 @@ use crate::{
     HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
     StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
 };
-use crate::{FunctionRef, Tools};
+use crate::{Function, FunctionRef, ToolCall, ToolType, Tools};
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
 use axum::response::sse::{Event, KeepAlive, Sse};
@@ -583,6 +583,16 @@ async fn chat_completions(
     let logprobs = req.logprobs.unwrap_or(false);
     let seed = req.seed;

+    if stream && req.tools.is_some() {
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Tools are not supported with stream".to_string(),
+                error_type: "Input validation error".to_string(),
+            }),
+        ));
+    }
+
     // apply chat template to flatten the request into a single input
     let mut inputs = match infer.apply_chat_template(req.messages) {
         Ok(inputs) => inputs,
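
Note: streaming and tools are mutually exclusive for now. A request combining `"stream": true` with a `tools` array is rejected up front with `422 Unprocessable Entity` rather than failing mid-stream, since the tool-call response below is assembled from the complete generated text.
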
@@ -599,16 +609,13 @@
         }
     };

-    // if theres a tools object, we need to decompose it and use the function name as the key
-    // and the parameters as the value in the "$functions" object.
-    let grammar = if let Some(ref req_tools) = &req.tools {
-        // get the tool_choice if there is one
-        let tool_choice = &req.tool_choice;
-        let tools_to_use = if let Some(tool_choice) = tool_choice {
-            // get the tool based on the tool_choice
-            let tool = req_tools
+    let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
+        let tool_prompt = req.tool_prompt.unwrap_or_default();
+        let tools_to_use = match tool_choice {
+            ToolType::FunctionName(name) => {
+                vec![req_tools
                 .iter()
-                .find(|tool| tool.function.name == *tool_choice)
+                .find(|tool| tool.function.name == *name)
                 .ok_or_else(|| {
                     (
                         StatusCode::UNPROCESSABLE_ENTITY,
@@ -617,34 +624,19 @@
                         error_type: "Input validation error".to_string(),
                     }),
                 )
-            })?;
-            vec![tool.clone()]
-        } else {
-            req_tools.clone()
+            })?
+            .clone()]
+            }
+            ToolType::OneOf => req_tools.to_owned(),
         };

-        let functions: HashMap<String, Value> = {
-            let mut tools = HashMap::new();
-            for tool in &tools_to_use {
-                let func = tool.function.clone();
-                let name = func.name;
-                let parameters = match func.parameters.as_object() {
-                    Some(parameters) => parameters.clone(),
-                    None => {
-                        return Err((
-                            StatusCode::UNPROCESSABLE_ENTITY,
-                            Json(ErrorResponse {
-                                error: "Input validation error".to_string(),
-                                error_type: "Input validation error".to_string(),
-                            }),
-                        ))
-                    }
-                };
-                tools.insert(name, Value::Object(parameters));
-            }
-            tools
-        };
+        let functions: HashMap<String, Value> = tools_to_use
+            .iter()
+            .map(|tool| {
+                let func = tool.function.clone();
+                (func.name, func.parameters)
+            })
+            .collect();

         let tools = Tools {
             function: functions,
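
Note: the map/collect rewrite also drops the old per-tool check that `parameters` is a JSON object; whatever `serde_json::Value` the client sent is now inserted into the functions map as-is.
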
@@ -654,7 +646,6 @@
                 .collect(),
         };

-        // update the input
         let tools_str = serde_json::to_string(&tools).map_err(|e| {
             (
                 StatusCode::UNPROCESSABLE_ENTITY,
@@ -664,12 +655,7 @@
                 }),
             )
         })?;
-
-        let tool_prompt =
-            "Based on the conversation, please choose the most appropriate tool to use:"
-                .to_string();
-
-        inputs = format!("{inputs}\n\n{tool_prompt}\n\n{tools_str}\n\n");
+        inputs = format!("{inputs}{tool_prompt}{tools_str}");
         Some(GrammarType::Json(tools.into()))
     } else {
         None
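
Note: the hard-coded prompt and `\n\n` separators are gone; the flattened input is now a plain concatenation, and the default `tool_prompt` supplies its own leading newline and trailing space. A sketch (chat-template output abbreviated, and the `"$functions"` key name taken from the removed comment above, so treat the serialized shape as an assumption):

    let inputs = "User: What is the weather in Paris?".to_string();
    let tool_prompt = "\nBased on the conversation, please choose the most appropriate tool to use: ";
    let tools_str = r#"{"$functions": {"get_weather": "..."}}"#; // abbreviated serialization of Tools
    let inputs = format!("{inputs}{tool_prompt}{tools_str}");
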
@@ -696,7 +682,7 @@
             decoder_input_details: !stream,
             seed,
             top_n_tokens: None,
-            grammar,
+            grammar: tool_grammar.clone(),
         },
     };
@@ -760,14 +746,41 @@
             .unwrap_or_else(|_| std::time::Duration::from_secs(0))
             .as_secs();

+        let (tool_calls, output) = if tool_grammar.is_some() {
+            // gen_text should be valid json
+            let gen_text_value: Value =
+                serde_json::from_str(&generation.generated_text).map_err(|e| {
+                    (
+                        StatusCode::UNPROCESSABLE_ENTITY,
+                        Json(ErrorResponse {
+                            error: e.to_string(),
+                            error_type: "Input validation error".to_string(),
+                        }),
+                    )
+                })?;
+
+            let tool_call = Some(ToolCall {
+                id: 0,
+                r#type: "function".to_string(),
+                function: Function {
+                    description: None,
+                    name: "tools".to_string(),
+                    parameters: gen_text_value.get("function").unwrap().clone(),
+                },
+            });
+
+            (tool_call, None)
+        } else {
+            (None, Some(generation.generated_text))
+        };
+
         // build the complete response object with the full text
         let response = ChatCompletion::new(
             model_id,
             system_fingerprint,
-            generation.generated_text,
+            output,
             current_time,
             generation.details.unwrap(),
             logprobs,
+            tool_calls,
         );

         // wrap generation inside a Vec to match api-inference
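
Note: because generation ran under the JSON grammar built from `Tools`, the raw output is expected to be a JSON object with a `function` key, which the handler lifts into a `ToolCall` under the fixed placeholder name `"tools"`. A sketch of the round trip (the payload keys are illustrative):

    // Illustrative grammar-constrained output; the keys inside "function" depend on the tool's schema.
    let generated_text = r#"{"function": {"location": "Paris", "format": "celsius"}}"#;
    let v: serde_json::Value = serde_json::from_str(generated_text).unwrap();
    // The handler wraps v["function"] into ToolCall { id: 0, type: "function", function: Function { name: "tools", .. } };
    // the .unwrap() on .get("function") relies on the grammar always producing that key.
    assert_eq!(v["function"]["location"], "Paris");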

server/text_generation_server/utils/logits_process.py

@@ -513,9 +513,6 @@ class GrammarLogitProcessor(LogitsProcessor):
         start_time = time.time()
         if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
             schema = build_regex_from_object(schema)
-        elif grammar_type == GrammarType.GRAMMAR_TYPE_OPTIONAL_JSON:
-            # TODO: use a better method to handle optional grammars
-            schema = f"({build_regex_from_object(schema)})|.*"
         elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
             pass  # schema is already a regex just here for clarity
         fsm = RegexFSM(schema, tokenizer)
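
Note: with tool choice resolved in the router and tool output always constrained by a strict JSON grammar, the `GRAMMAR_TYPE_OPTIONAL_JSON` escape hatch (a regex union with `.*`) no longer has a caller, so it is removed from both the proto enum and the logits processor.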