mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 12:54:52 +00:00
feat: improve default tool serialization and lints
This commit is contained in:
parent
2ee98c7c07
commit
9ea34977ac
@ -65,11 +65,14 @@ impl ChatTemplate {
|
|||||||
|
|
||||||
let (tools, tool_prompt) = tools_and_prompt.unwrap_or_default();
|
let (tools, tool_prompt) = tools_and_prompt.unwrap_or_default();
|
||||||
|
|
||||||
if tools.is_some() {
|
if let Some(ref tools) = tools {
|
||||||
// check if the `tools` variable is used in the template
|
// check if the `tools` variable is used in the template
|
||||||
// if not, we need to append the tools to the last message
|
// if not, we need to append the tools to the last message
|
||||||
let text = if self.use_default_tool_template {
|
let text = if self.use_default_tool_template {
|
||||||
format!("\n---\n{:?}\n{}", tools, tool_prompt)
|
match serde_json::to_string(tools) {
|
||||||
|
Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
|
||||||
|
Err(e) => return Err(InferError::ToolError(e.to_string())),
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// if the `tools` variable is used in the template, we just append the tool_prompt
|
// if the `tools` variable is used in the template, we just append the tool_prompt
|
||||||
format!("\n---\n{}", tool_prompt)
|
format!("\n---\n{}", tool_prompt)
|
||||||
@ -81,18 +84,17 @@ impl ChatTemplate {
|
|||||||
|
|
||||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||||
|
|
||||||
return self
|
self.template
|
||||||
.template
|
|
||||||
.render(ChatTemplateInputs {
|
.render(ChatTemplateInputs {
|
||||||
guideline,
|
guideline,
|
||||||
messages,
|
messages,
|
||||||
bos_token: self.bos_token.as_deref(),
|
bos_token: self.bos_token.as_deref(),
|
||||||
eos_token: self.eos_token.as_deref(),
|
eos_token: self.eos_token.as_deref(),
|
||||||
add_generation_prompt: true,
|
add_generation_prompt: true,
|
||||||
tools: tools,
|
tools,
|
||||||
tools_prompt: None,
|
tools_prompt: None,
|
||||||
})
|
})
|
||||||
.map_err(InferError::TemplateError);
|
.map_err(InferError::TemplateError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -102,7 +104,8 @@ mod tests {
|
|||||||
use crate::infer::chat_template::raise_exception;
|
use crate::infer::chat_template::raise_exception;
|
||||||
use crate::infer::ChatTemplate;
|
use crate::infer::ChatTemplate;
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage, TokenizerConfigToken,
|
ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage,
|
||||||
|
TokenizerConfigToken, Tool,
|
||||||
};
|
};
|
||||||
use minijinja::Environment;
|
use minijinja::Environment;
|
||||||
|
|
||||||
@ -861,11 +864,12 @@ mod tests {
|
|||||||
content: MessageContent::SingleText("Just testing".to_string()),
|
content: MessageContent::SingleText("Just testing".to_string()),
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
let tools = serde_json::json!("[]");
|
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
|
||||||
|
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
|
||||||
let tool_prompt = "This default prompt will be used".to_string();
|
let tool_prompt = "This default prompt will be used".to_string();
|
||||||
let grammer_with_prompt = (GrammarType::Json(tools), tool_prompt);
|
let tools_and_prompt = Some((Some(tools), tool_prompt));
|
||||||
let result = ct.apply(None, msgs, Some(grammer_with_prompt));
|
let result = ct.apply(None, msgs, tools_and_prompt);
|
||||||
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\nThis default prompt will be used\n\"[]\" [/INST]".to_string();
|
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
|
||||||
assert_eq!(result.unwrap(), expected);
|
assert_eq!(result.unwrap(), expected);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1207,6 +1207,7 @@ pub(crate) struct GenerateResponse {
|
|||||||
pub(crate) struct ChatTokenizeResponse {
|
pub(crate) struct ChatTokenizeResponse {
|
||||||
pub(crate) tokenize_response: TokenizeResponse,
|
pub(crate) tokenize_response: TokenizeResponse,
|
||||||
pub(crate) templated_text: String,
|
pub(crate) templated_text: String,
|
||||||
|
pub(crate) using_tools: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema)]
|
||||||
|
@ -146,7 +146,7 @@ async fn get_chat_tokenize(
|
|||||||
} = req;
|
} = req;
|
||||||
|
|
||||||
let tool_prompt = tool_prompt.unwrap_or_default();
|
let tool_prompt = tool_prompt.unwrap_or_default();
|
||||||
let (inputs, _grammar, _tool_grammar) = prepare_chat_input(
|
let (inputs, _grammar, using_tools) = prepare_chat_input(
|
||||||
&infer,
|
&infer,
|
||||||
response_format,
|
response_format,
|
||||||
tools,
|
tools,
|
||||||
@ -206,6 +206,7 @@ async fn get_chat_tokenize(
|
|||||||
let resp = ChatTokenizeResponse {
|
let resp = ChatTokenizeResponse {
|
||||||
tokenize_response: TokenizeResponse(tokens),
|
tokenize_response: TokenizeResponse(tokens),
|
||||||
templated_text: input,
|
templated_text: input,
|
||||||
|
using_tools,
|
||||||
};
|
};
|
||||||
Ok((HeaderMap::new(), Json(resp)))
|
Ok((HeaderMap::new(), Json(resp)))
|
||||||
} else {
|
} else {
|
||||||
@ -1165,7 +1166,7 @@ async fn chat_completions(
|
|||||||
Some(temperature) if temperature == 0.0 => (false, None),
|
Some(temperature) if temperature == 0.0 => (false, None),
|
||||||
other => (true, other),
|
other => (true, other),
|
||||||
};
|
};
|
||||||
let (inputs, grammar, tool_grammar) = prepare_chat_input(
|
let (inputs, grammar, using_tools) = prepare_chat_input(
|
||||||
&infer,
|
&infer,
|
||||||
response_format,
|
response_format,
|
||||||
tools,
|
tools,
|
||||||
@ -1221,7 +1222,7 @@ async fn chat_completions(
|
|||||||
});
|
});
|
||||||
|
|
||||||
// replace the content with the tool calls if grammar is present
|
// replace the content with the tool calls if grammar is present
|
||||||
let (content, tool_calls) = if tool_grammar.is_some() {
|
let (content, tool_calls) = if using_tools {
|
||||||
(None, Some(vec![stream_token.token.text]))
|
(None, Some(vec![stream_token.token.text]))
|
||||||
} else {
|
} else {
|
||||||
let content = if !stream_token.token.special {
|
let content = if !stream_token.token.special {
|
||||||
@ -1275,7 +1276,7 @@ async fn chat_completions(
|
|||||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||||
.as_secs();
|
.as_secs();
|
||||||
|
|
||||||
let (tool_calls, output) = if tool_grammar.is_some() {
|
let (tool_calls, output) = if using_tools {
|
||||||
let gen_text_value: Value =
|
let gen_text_value: Value =
|
||||||
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
||||||
InferError::ToolError(format!(
|
InferError::ToolError(format!(
|
||||||
@ -2539,7 +2540,7 @@ fn create_post_processor(
|
|||||||
Ok(post_processor)
|
Ok(post_processor)
|
||||||
}
|
}
|
||||||
|
|
||||||
type PreparedInput = (String, Option<GrammarType>, Option<Tools>);
|
type PreparedInput = (String, Option<GrammarType>, bool);
|
||||||
|
|
||||||
fn prepare_chat_input(
|
fn prepare_chat_input(
|
||||||
infer: &Infer,
|
infer: &Infer,
|
||||||
@ -2558,7 +2559,7 @@ fn prepare_chat_input(
|
|||||||
|
|
||||||
if let Some(format) = response_format {
|
if let Some(format) = response_format {
|
||||||
let inputs = infer.apply_chat_template(guideline, messages, None)?;
|
let inputs = infer.apply_chat_template(guideline, messages, None)?;
|
||||||
return Ok((inputs, Some(format), None));
|
return Ok((inputs, Some(format), false));
|
||||||
}
|
}
|
||||||
|
|
||||||
// if tools are set, apply the tool grammar and then the chat template
|
// if tools are set, apply the tool grammar and then the chat template
|
||||||
@ -2568,5 +2569,119 @@ fn prepare_chat_input(
|
|||||||
.map(|t| GrammarType::Json(serde_json::json!(t)));
|
.map(|t| GrammarType::Json(serde_json::json!(t)));
|
||||||
let tools_and_prompt: (Option<Vec<Tool>>, String) = (tools, tool_prompt.into());
|
let tools_and_prompt: (Option<Vec<Tool>>, String) = (tools, tool_prompt.into());
|
||||||
let inputs = infer.apply_chat_template(guideline, messages, Some(tools_and_prompt))?;
|
let inputs = infer.apply_chat_template(guideline, messages, Some(tools_and_prompt))?;
|
||||||
Ok((inputs, grammar, tool_grammar))
|
Ok((inputs, grammar, tool_grammar.is_some()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use crate::ChatTemplateVersions;
|
||||||
|
use crate::FunctionsMap;
|
||||||
|
use crate::HubTokenizerConfig;
|
||||||
|
use crate::Properties;
|
||||||
|
use crate::TokenizerConfigToken;
|
||||||
|
use crate::Tool;
|
||||||
|
use crate::Tools;
|
||||||
|
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_prepare_chat_input() {
|
||||||
|
// Mock Backend to avoid network requests
|
||||||
|
struct MockBackend;
|
||||||
|
|
||||||
|
impl Backend for MockBackend {
|
||||||
|
fn schedule(
|
||||||
|
&self,
|
||||||
|
request: crate::validation::ValidGenerateRequest,
|
||||||
|
) -> Result<
|
||||||
|
tokio_stream::wrappers::UnboundedReceiverStream<
|
||||||
|
Result<InferStreamResponse, InferError>,
|
||||||
|
>,
|
||||||
|
InferError,
|
||||||
|
> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
fn health<'life0, 'async_trait>(
|
||||||
|
&'life0 self,
|
||||||
|
current_health: bool,
|
||||||
|
) -> core::pin::Pin<
|
||||||
|
Box<dyn core::future::Future<Output = bool> + core::marker::Send + 'async_trait>,
|
||||||
|
>
|
||||||
|
where
|
||||||
|
'life0: 'async_trait,
|
||||||
|
Self: 'async_trait,
|
||||||
|
{
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let backend = MockBackend {};
|
||||||
|
|
||||||
|
let mut tokenizer_config = HubTokenizerConfig::default();
|
||||||
|
|
||||||
|
// mock tokenizer config values
|
||||||
|
tokenizer_config.bos_token = Some(TokenizerConfigToken::String("<s>".to_string()));
|
||||||
|
tokenizer_config.eos_token = Some(TokenizerConfigToken::String("</s>".to_string()));
|
||||||
|
tokenizer_config.chat_template = Some(
|
||||||
|
ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string())
|
||||||
|
);
|
||||||
|
|
||||||
|
let infer = Infer::new(
|
||||||
|
backend,
|
||||||
|
Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false),
|
||||||
|
1,
|
||||||
|
tokenizer_config,
|
||||||
|
HubProcessorConfig::default(),
|
||||||
|
);
|
||||||
|
let response_format = None;
|
||||||
|
let tools = Some(vec![Tool {
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: FunctionDefinition {
|
||||||
|
name: "get_current_weather".to_string(),
|
||||||
|
description: Some("Get the current weather".to_string()),
|
||||||
|
arguments: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA"
|
||||||
|
},
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "The temperature unit to use. Infer this from the users location."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location", "format"]
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}]);
|
||||||
|
let tool_prompt = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.";
|
||||||
|
let guideline = None;
|
||||||
|
let messages = vec![Message {
|
||||||
|
name: None,
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: MessageContent::SingleText(
|
||||||
|
"What is the weather like in New York?".to_string(),
|
||||||
|
),
|
||||||
|
}];
|
||||||
|
|
||||||
|
let result = prepare_chat_input(
|
||||||
|
&infer,
|
||||||
|
response_format,
|
||||||
|
tools,
|
||||||
|
ToolChoice(None),
|
||||||
|
tool_prompt,
|
||||||
|
guideline,
|
||||||
|
messages,
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let (inputs, grammar, using_tools) = result.unwrap();
|
||||||
|
assert_eq!(using_tools, true);
|
||||||
|
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user