mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: improve tools api and add tool prompt
This commit is contained in:
parent
0e30e65822
commit
a32d3dd6cb
@ -55,7 +55,6 @@ enum GrammarType {
|
|||||||
GRAMMAR_TYPE_NONE = 0;
|
GRAMMAR_TYPE_NONE = 0;
|
||||||
GRAMMAR_TYPE_JSON = 1;
|
GRAMMAR_TYPE_JSON = 1;
|
||||||
GRAMMAR_TYPE_REGEX = 2;
|
GRAMMAR_TYPE_REGEX = 2;
|
||||||
GRAMMAR_TYPE_OPTIONAL_JSON = 3;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
message NextTokenChooserParameters {
|
message NextTokenChooserParameters {
|
||||||
|
@ -812,23 +812,27 @@ mod tests {
|
|||||||
messages: vec![
|
messages: vec![
|
||||||
Message {
|
Message {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi!".to_string(),
|
content: Some("Hi!".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
},
|
},
|
||||||
Message {
|
Message {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "Hello how can I help?".to_string(),
|
content: Some("Hello how can I help?".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
},
|
},
|
||||||
Message {
|
Message {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "What is Deep Learning?".to_string(),
|
content: Some("What is Deep Learning?".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
},
|
},
|
||||||
Message {
|
Message {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "magic!".to_string(),
|
content: Some("magic!".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -877,28 +881,33 @@ mod tests {
|
|||||||
messages: vec![
|
messages: vec![
|
||||||
Message {
|
Message {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi!".to_string(),
|
content: Some("Hi!".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
},
|
},
|
||||||
Message {
|
Message {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi again!".to_string(),
|
content: Some("Hi again!".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
},
|
},
|
||||||
Message {
|
Message {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "Hello how can I help?".to_string(),
|
content: Some("Hello how can I help?".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
},
|
},
|
||||||
Message {
|
Message {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "What is Deep Learning?".to_string(),
|
content: Some("What is Deep Learning?".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
},
|
},
|
||||||
Message {
|
Message {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "magic!".to_string(),
|
content: Some("magic!".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -952,23 +961,27 @@ mod tests {
|
|||||||
messages: vec![
|
messages: vec![
|
||||||
Message {
|
Message {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi!".to_string(),
|
content: Some("Hi!".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
},
|
},
|
||||||
Message {
|
Message {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "Hello how can I help?".to_string(),
|
content: Some("Hello how can I help?".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
},
|
},
|
||||||
Message {
|
Message {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "What is Deep Learning?".to_string(),
|
content: Some("What is Deep Learning?".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
},
|
},
|
||||||
Message {
|
Message {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "magic!".to_string(),
|
content: Some("magic!".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -1006,23 +1019,27 @@ mod tests {
|
|||||||
messages: vec![
|
messages: vec![
|
||||||
Message {
|
Message {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi!".to_string(),
|
content: Some("Hi!".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
},
|
},
|
||||||
Message {
|
Message {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "Hello how can I help?".to_string(),
|
content: Some("Hello how can I help?".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
},
|
},
|
||||||
Message {
|
Message {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "What is Deep Learning?".to_string(),
|
content: Some("What is Deep Learning?".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
},
|
},
|
||||||
Message {
|
Message {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "magic!".to_string(),
|
content: Some("magic!".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
|
@ -358,10 +358,11 @@ impl ChatCompletion {
|
|||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
model: String,
|
model: String,
|
||||||
system_fingerprint: String,
|
system_fingerprint: String,
|
||||||
output: String,
|
output: Option<String>,
|
||||||
created: u64,
|
created: u64,
|
||||||
details: Details,
|
details: Details,
|
||||||
return_logprobs: bool,
|
return_logprobs: bool,
|
||||||
|
tool_calls: Option<ToolCall>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
id: String::new(),
|
id: String::new(),
|
||||||
@ -375,6 +376,7 @@ impl ChatCompletion {
|
|||||||
role: "assistant".into(),
|
role: "assistant".into(),
|
||||||
content: output,
|
content: output,
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls,
|
||||||
},
|
},
|
||||||
logprobs: return_logprobs
|
logprobs: return_logprobs
|
||||||
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
|
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
|
||||||
@ -527,10 +529,61 @@ pub(crate) struct ChatRequest {
|
|||||||
#[schema(nullable = true, example = "null")]
|
#[schema(nullable = true, example = "null")]
|
||||||
pub tools: Option<Vec<Tool>>,
|
pub tools: Option<Vec<Tool>>,
|
||||||
|
|
||||||
|
/// A prompt to be appended before the tools
|
||||||
|
#[serde(default = "default_tool_prompt")]
|
||||||
|
pub tool_prompt: Option<String>,
|
||||||
|
|
||||||
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(nullable = true, example = "null")]
|
#[schema(nullable = true, example = "null")]
|
||||||
pub tool_choice: Option<String>,
|
#[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
|
||||||
|
pub tool_choice: Option<ToolType>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_tool_prompt() -> Option<String> {
|
||||||
|
Some(
|
||||||
|
"\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||||
|
enum ToolType {
|
||||||
|
FunctionName(String),
|
||||||
|
OneOf,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None)
|
||||||
|
mod deserialize_tool_choice {
|
||||||
|
use super::*;
|
||||||
|
use serde::de;
|
||||||
|
use serde::Deserializer;
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ToolType>, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let value = Value::deserialize(deserializer)?;
|
||||||
|
|
||||||
|
match value {
|
||||||
|
Value::String(s) => match s.as_str() {
|
||||||
|
"none" => Ok(None),
|
||||||
|
"auto" => Ok(Some(ToolType::OneOf)),
|
||||||
|
_ => Ok(Some(ToolType::FunctionName(s))),
|
||||||
|
},
|
||||||
|
Value::Object(map) => {
|
||||||
|
if let Some(content) = map
|
||||||
|
.get("function")
|
||||||
|
.and_then(|v| v.get("name"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
{
|
||||||
|
Ok(Some(ToolType::FunctionName(content.to_string())))
|
||||||
|
} else {
|
||||||
|
Err(de::Error::custom("function key not found in tool choice"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(de::Error::custom("invalid token format")),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
|
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
|
||||||
@ -575,7 +628,8 @@ impl FunctionRef {
|
|||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||||
pub(crate) struct Function {
|
pub(crate) struct Function {
|
||||||
pub description: String,
|
#[serde(default)]
|
||||||
|
pub description: Option<String>,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub parameters: serde_json::Value,
|
pub parameters: serde_json::Value,
|
||||||
}
|
}
|
||||||
@ -597,15 +651,24 @@ pub(crate) struct ChatTemplateInputs<'a> {
|
|||||||
add_generation_prompt: bool,
|
add_generation_prompt: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
||||||
|
pub(crate) struct ToolCall {
|
||||||
|
pub id: u32,
|
||||||
|
pub r#type: String,
|
||||||
|
pub function: Function,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||||
pub(crate) struct Message {
|
pub(crate) struct Message {
|
||||||
#[schema(example = "user")]
|
#[schema(example = "user")]
|
||||||
pub role: String,
|
pub role: String,
|
||||||
#[schema(example = "My name is David and I")]
|
#[schema(example = "My name is David and I")]
|
||||||
pub content: String,
|
pub content: Option<String>,
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
#[schema(example = "\"David\"")]
|
#[schema(example = "\"David\"")]
|
||||||
pub name: Option<String>,
|
pub name: Option<String>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_calls: Option<ToolCall>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||||
|
@ -10,7 +10,7 @@ use crate::{
|
|||||||
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
|
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
|
||||||
StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
|
StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
|
||||||
};
|
};
|
||||||
use crate::{FunctionRef, Tools};
|
use crate::{Function, FunctionRef, ToolCall, ToolType, Tools};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, Method, StatusCode};
|
use axum::http::{HeaderMap, Method, StatusCode};
|
||||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||||
@ -583,6 +583,16 @@ async fn chat_completions(
|
|||||||
let logprobs = req.logprobs.unwrap_or(false);
|
let logprobs = req.logprobs.unwrap_or(false);
|
||||||
let seed = req.seed;
|
let seed = req.seed;
|
||||||
|
|
||||||
|
if stream && req.tools.is_some() {
|
||||||
|
return Err((
|
||||||
|
StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: "Tools are not supported with stream".to_string(),
|
||||||
|
error_type: "Input validation error".to_string(),
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
// apply chat template to flatten the request into a single input
|
// apply chat template to flatten the request into a single input
|
||||||
let mut inputs = match infer.apply_chat_template(req.messages) {
|
let mut inputs = match infer.apply_chat_template(req.messages) {
|
||||||
Ok(inputs) => inputs,
|
Ok(inputs) => inputs,
|
||||||
@ -599,16 +609,13 @@ async fn chat_completions(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// if theres a tools object, we need to decompose it and use the function name as the key
|
let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
|
||||||
// and the parameters as the value in the "$functions" object.
|
let tool_prompt = req.tool_prompt.unwrap_or_default();
|
||||||
let grammar = if let Some(ref req_tools) = &req.tools {
|
let tools_to_use = match tool_choice {
|
||||||
// get the tool_choice if there is one
|
ToolType::FunctionName(name) => {
|
||||||
let tool_choice = &req.tool_choice;
|
vec![req_tools
|
||||||
let tools_to_use = if let Some(tool_choice) = tool_choice {
|
|
||||||
// get the tool based on the tool_choice
|
|
||||||
let tool = req_tools
|
|
||||||
.iter()
|
.iter()
|
||||||
.find(|tool| tool.function.name == *tool_choice)
|
.find(|tool| tool.function.name == *name)
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
(
|
(
|
||||||
StatusCode::UNPROCESSABLE_ENTITY,
|
StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
@ -617,34 +624,19 @@ async fn chat_completions(
|
|||||||
error_type: "Input validation error".to_string(),
|
error_type: "Input validation error".to_string(),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
})?;
|
})?
|
||||||
vec![tool.clone()]
|
.clone()]
|
||||||
} else {
|
}
|
||||||
req_tools.clone()
|
ToolType::OneOf => req_tools.to_owned(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let functions: HashMap<String, Value> = {
|
let functions: HashMap<String, Value> = tools_to_use
|
||||||
let mut tools = HashMap::new();
|
.iter()
|
||||||
for tool in &tools_to_use {
|
.map(|tool| {
|
||||||
let func = tool.function.clone();
|
let func = tool.function.clone();
|
||||||
let name = func.name;
|
(func.name, func.parameters)
|
||||||
let parameters = match func.parameters.as_object() {
|
})
|
||||||
Some(parameters) => parameters.clone(),
|
.collect();
|
||||||
None => {
|
|
||||||
return Err((
|
|
||||||
StatusCode::UNPROCESSABLE_ENTITY,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: "Input validation error".to_string(),
|
|
||||||
error_type: "Input validation error".to_string(),
|
|
||||||
}),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
tools.insert(name, Value::Object(parameters));
|
|
||||||
}
|
|
||||||
tools
|
|
||||||
};
|
|
||||||
|
|
||||||
let tools = Tools {
|
let tools = Tools {
|
||||||
function: functions,
|
function: functions,
|
||||||
@ -654,7 +646,6 @@ async fn chat_completions(
|
|||||||
.collect(),
|
.collect(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// update the input
|
|
||||||
let tools_str = serde_json::to_string(&tools).map_err(|e| {
|
let tools_str = serde_json::to_string(&tools).map_err(|e| {
|
||||||
(
|
(
|
||||||
StatusCode::UNPROCESSABLE_ENTITY,
|
StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
@ -664,12 +655,7 @@ async fn chat_completions(
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
inputs = format!("{inputs}{tool_prompt}{tools_str}");
|
||||||
let tool_prompt =
|
|
||||||
"Based on the conversation, please choose the most appropriate tool to use:"
|
|
||||||
.to_string();
|
|
||||||
inputs = format!("{inputs}\n\n{tool_prompt}\n\n{tools_str}\n\n");
|
|
||||||
|
|
||||||
Some(GrammarType::Json(tools.into()))
|
Some(GrammarType::Json(tools.into()))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
@ -696,7 +682,7 @@ async fn chat_completions(
|
|||||||
decoder_input_details: !stream,
|
decoder_input_details: !stream,
|
||||||
seed,
|
seed,
|
||||||
top_n_tokens: None,
|
top_n_tokens: None,
|
||||||
grammar,
|
grammar: tool_grammar.clone(),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -760,14 +746,41 @@ async fn chat_completions(
|
|||||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||||
.as_secs();
|
.as_secs();
|
||||||
|
|
||||||
|
let (tool_calls, output) = if tool_grammar.is_some() {
|
||||||
|
// gen_text should be valid json
|
||||||
|
let gen_text_value: Value =
|
||||||
|
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
||||||
|
(
|
||||||
|
StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: e.to_string(),
|
||||||
|
error_type: "Input validation error".to_string(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let tool_call = Some(ToolCall {
|
||||||
|
id: 0,
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: Function {
|
||||||
|
description: None,
|
||||||
|
name: "tools".to_string(),
|
||||||
|
parameters: gen_text_value.get("function").unwrap().clone(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
(tool_call, None)
|
||||||
|
} else {
|
||||||
|
(None, Some(generation.generated_text))
|
||||||
|
};
|
||||||
// build the complete response object with the full text
|
// build the complete response object with the full text
|
||||||
let response = ChatCompletion::new(
|
let response = ChatCompletion::new(
|
||||||
model_id,
|
model_id,
|
||||||
system_fingerprint,
|
system_fingerprint,
|
||||||
generation.generated_text,
|
output,
|
||||||
current_time,
|
current_time,
|
||||||
generation.details.unwrap(),
|
generation.details.unwrap(),
|
||||||
logprobs,
|
logprobs,
|
||||||
|
tool_calls,
|
||||||
);
|
);
|
||||||
|
|
||||||
// wrap generation inside a Vec to match api-inference
|
// wrap generation inside a Vec to match api-inference
|
||||||
|
@ -513,9 +513,6 @@ class GrammarLogitProcessor(LogitsProcessor):
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
|
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
|
||||||
schema = build_regex_from_object(schema)
|
schema = build_regex_from_object(schema)
|
||||||
elif grammar_type == GrammarType.GRAMMAR_TYPE_OPTIONAL_JSON:
|
|
||||||
# TODO: use a better method to handle optional grammars
|
|
||||||
schema = f"({build_regex_from_object(schema)})|.*"
|
|
||||||
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
|
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
|
||||||
pass # schema is already a regex just here for clarity
|
pass # schema is already a regex just here for clarity
|
||||||
fsm = RegexFSM(schema, tokenizer)
|
fsm = RegexFSM(schema, tokenizer)
|
||||||
|
Loading…
Reference in New Issue
Block a user