mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: improve tools api and add tool prompt
This commit is contained in:
parent
0e30e65822
commit
a32d3dd6cb
@ -55,7 +55,6 @@ enum GrammarType {
|
||||
GRAMMAR_TYPE_NONE = 0;
|
||||
GRAMMAR_TYPE_JSON = 1;
|
||||
GRAMMAR_TYPE_REGEX = 2;
|
||||
GRAMMAR_TYPE_OPTIONAL_JSON = 3;
|
||||
}
|
||||
|
||||
message NextTokenChooserParameters {
|
||||
|
@ -812,23 +812,27 @@ mod tests {
|
||||
messages: vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hi!".to_string(),
|
||||
content: Some("Hi!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "Hello how can I help?".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "What is Deep Learning?".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "magic!".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
@ -877,28 +881,33 @@ mod tests {
|
||||
messages: vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hi!".to_string(),
|
||||
content: Some("Hi!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hi again!".to_string(),
|
||||
content: Some("Hi again!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "Hello how can I help?".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "What is Deep Learning?".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "magic!".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
@ -952,23 +961,27 @@ mod tests {
|
||||
messages: vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hi!".to_string(),
|
||||
content: Some("Hi!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "Hello how can I help?".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "What is Deep Learning?".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "magic!".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
@ -1006,23 +1019,27 @@ mod tests {
|
||||
messages: vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hi!".to_string(),
|
||||
content: Some("Hi!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "Hello how can I help?".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "What is Deep Learning?".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "magic!".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
|
@ -358,10 +358,11 @@ impl ChatCompletion {
|
||||
pub(crate) fn new(
|
||||
model: String,
|
||||
system_fingerprint: String,
|
||||
output: String,
|
||||
output: Option<String>,
|
||||
created: u64,
|
||||
details: Details,
|
||||
return_logprobs: bool,
|
||||
tool_calls: Option<ToolCall>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: String::new(),
|
||||
@ -375,6 +376,7 @@ impl ChatCompletion {
|
||||
role: "assistant".into(),
|
||||
content: output,
|
||||
name: None,
|
||||
tool_calls,
|
||||
},
|
||||
logprobs: return_logprobs
|
||||
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
|
||||
@ -527,10 +529,61 @@ pub(crate) struct ChatRequest {
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub tools: Option<Vec<Tool>>,
|
||||
|
||||
/// A prompt to be appended before the tools
|
||||
#[serde(default = "default_tool_prompt")]
|
||||
pub tool_prompt: Option<String>,
|
||||
|
||||
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub tool_choice: Option<String>,
|
||||
#[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
|
||||
pub tool_choice: Option<ToolType>,
|
||||
}
|
||||
|
||||
fn default_tool_prompt() -> Option<String> {
|
||||
Some(
|
||||
"\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(),
|
||||
)
|
||||
}
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
enum ToolType {
|
||||
FunctionName(String),
|
||||
OneOf,
|
||||
}
|
||||
|
||||
/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None)
|
||||
mod deserialize_tool_choice {
|
||||
use super::*;
|
||||
use serde::de;
|
||||
use serde::Deserializer;
|
||||
use serde_json::Value;
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ToolType>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = Value::deserialize(deserializer)?;
|
||||
|
||||
match value {
|
||||
Value::String(s) => match s.as_str() {
|
||||
"none" => Ok(None),
|
||||
"auto" => Ok(Some(ToolType::OneOf)),
|
||||
_ => Ok(Some(ToolType::FunctionName(s))),
|
||||
},
|
||||
Value::Object(map) => {
|
||||
if let Some(content) = map
|
||||
.get("function")
|
||||
.and_then(|v| v.get("name"))
|
||||
.and_then(|v| v.as_str())
|
||||
{
|
||||
Ok(Some(ToolType::FunctionName(content.to_string())))
|
||||
} else {
|
||||
Err(de::Error::custom("function key not found in tool choice"))
|
||||
}
|
||||
}
|
||||
_ => Err(de::Error::custom("invalid token format")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
|
||||
@ -575,7 +628,8 @@ impl FunctionRef {
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||
pub(crate) struct Function {
|
||||
pub description: String,
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
pub name: String,
|
||||
pub parameters: serde_json::Value,
|
||||
}
|
||||
@ -597,15 +651,24 @@ pub(crate) struct ChatTemplateInputs<'a> {
|
||||
add_generation_prompt: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
||||
pub(crate) struct ToolCall {
|
||||
pub id: u32,
|
||||
pub r#type: String,
|
||||
pub function: Function,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
pub(crate) struct Message {
|
||||
#[schema(example = "user")]
|
||||
pub role: String,
|
||||
#[schema(example = "My name is David and I")]
|
||||
pub content: String,
|
||||
pub content: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[schema(example = "\"David\"")]
|
||||
pub name: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<ToolCall>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||
|
@ -10,7 +10,7 @@ use crate::{
|
||||
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
|
||||
StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
|
||||
};
|
||||
use crate::{FunctionRef, Tools};
|
||||
use crate::{Function, FunctionRef, ToolCall, ToolType, Tools};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
@ -583,6 +583,16 @@ async fn chat_completions(
|
||||
let logprobs = req.logprobs.unwrap_or(false);
|
||||
let seed = req.seed;
|
||||
|
||||
if stream && req.tools.is_some() {
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: "Tools are not supported with stream".to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
// apply chat template to flatten the request into a single input
|
||||
let mut inputs = match infer.apply_chat_template(req.messages) {
|
||||
Ok(inputs) => inputs,
|
||||
@ -599,53 +609,35 @@ async fn chat_completions(
|
||||
}
|
||||
};
|
||||
|
||||
// if theres a tools object, we need to decompose it and use the function name as the key
|
||||
// and the parameters as the value in the "$functions" object.
|
||||
let grammar = if let Some(ref req_tools) = &req.tools {
|
||||
// get the tool_choice if there is one
|
||||
let tool_choice = &req.tool_choice;
|
||||
let tools_to_use = if let Some(tool_choice) = tool_choice {
|
||||
// get the tool based on the tool_choice
|
||||
let tool = req_tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == *tool_choice)
|
||||
.ok_or_else(|| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: "Input validation error".to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
vec![tool.clone()]
|
||||
} else {
|
||||
req_tools.clone()
|
||||
};
|
||||
|
||||
let functions: HashMap<String, Value> = {
|
||||
let mut tools = HashMap::new();
|
||||
for tool in &tools_to_use {
|
||||
let func = tool.function.clone();
|
||||
let name = func.name;
|
||||
let parameters = match func.parameters.as_object() {
|
||||
Some(parameters) => parameters.clone(),
|
||||
None => {
|
||||
return Err((
|
||||
let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
|
||||
let tool_prompt = req.tool_prompt.unwrap_or_default();
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolType::FunctionName(name) => {
|
||||
vec![req_tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == *name)
|
||||
.ok_or_else(|| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: "Input validation error".to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
tools.insert(name, Value::Object(parameters));
|
||||
)
|
||||
})?
|
||||
.clone()]
|
||||
}
|
||||
tools
|
||||
ToolType::OneOf => req_tools.to_owned(),
|
||||
};
|
||||
|
||||
let functions: HashMap<String, Value> = tools_to_use
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
let func = tool.function.clone();
|
||||
(func.name, func.parameters)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let tools = Tools {
|
||||
function: functions,
|
||||
any_of: tools_to_use
|
||||
@ -654,7 +646,6 @@ async fn chat_completions(
|
||||
.collect(),
|
||||
};
|
||||
|
||||
// update the input
|
||||
let tools_str = serde_json::to_string(&tools).map_err(|e| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
@ -664,12 +655,7 @@ async fn chat_completions(
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
|
||||
let tool_prompt =
|
||||
"Based on the conversation, please choose the most appropriate tool to use:"
|
||||
.to_string();
|
||||
inputs = format!("{inputs}\n\n{tool_prompt}\n\n{tools_str}\n\n");
|
||||
|
||||
inputs = format!("{inputs}{tool_prompt}{tools_str}");
|
||||
Some(GrammarType::Json(tools.into()))
|
||||
} else {
|
||||
None
|
||||
@ -696,7 +682,7 @@ async fn chat_completions(
|
||||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: None,
|
||||
grammar,
|
||||
grammar: tool_grammar.clone(),
|
||||
},
|
||||
};
|
||||
|
||||
@ -760,14 +746,41 @@ async fn chat_completions(
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let (tool_calls, output) = if tool_grammar.is_some() {
|
||||
// gen_text should be valid json
|
||||
let gen_text_value: Value =
|
||||
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: e.to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
|
||||
let tool_call = Some(ToolCall {
|
||||
id: 0,
|
||||
r#type: "function".to_string(),
|
||||
function: Function {
|
||||
description: None,
|
||||
name: "tools".to_string(),
|
||||
parameters: gen_text_value.get("function").unwrap().clone(),
|
||||
},
|
||||
});
|
||||
(tool_call, None)
|
||||
} else {
|
||||
(None, Some(generation.generated_text))
|
||||
};
|
||||
// build the complete response object with the full text
|
||||
let response = ChatCompletion::new(
|
||||
model_id,
|
||||
system_fingerprint,
|
||||
generation.generated_text,
|
||||
output,
|
||||
current_time,
|
||||
generation.details.unwrap(),
|
||||
logprobs,
|
||||
tool_calls,
|
||||
);
|
||||
|
||||
// wrap generation inside a Vec to match api-inference
|
||||
|
@ -513,9 +513,6 @@ class GrammarLogitProcessor(LogitsProcessor):
|
||||
start_time = time.time()
|
||||
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
|
||||
schema = build_regex_from_object(schema)
|
||||
elif grammar_type == GrammarType.GRAMMAR_TYPE_OPTIONAL_JSON:
|
||||
# TODO: use a better method to handle optional grammars
|
||||
schema = f"({build_regex_from_object(schema)})|.*"
|
||||
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
|
||||
pass # schema is already a regex just here for clarity
|
||||
fsm = RegexFSM(schema, tokenizer)
|
||||
|
Loading…
Reference in New Issue
Block a user