mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-04 05:32:11 +00:00
Add strftime_now
callable function for minijinja
chat templates (#2983)
* Add `chrono` and `strftime_now` function callable * Fix `test_chat_template_valid_with_strftime_now` * Fix `test_chat_template_valid_with_strftime_now`
This commit is contained in:
parent
e3f2018cb5
commit
88fd56f549
53
Cargo.lock
generated
53
Cargo.lock
generated
@ -52,6 +52,21 @@ version = "0.2.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9"
|
||||
|
||||
[[package]]
|
||||
name = "android-tzdata"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
|
||||
|
||||
[[package]]
|
||||
name = "android_system_properties"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstream"
|
||||
version = "0.6.18"
|
||||
@ -651,6 +666,20 @@ version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
|
||||
|
||||
[[package]]
|
||||
name = "chrono"
|
||||
version = "0.4.39"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825"
|
||||
dependencies = [
|
||||
"android-tzdata",
|
||||
"iana-time-zone",
|
||||
"js-sys",
|
||||
"num-traits",
|
||||
"wasm-bindgen",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clang-sys"
|
||||
version = "1.8.1"
|
||||
@ -1801,6 +1830,29 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone"
|
||||
version = "0.1.61"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220"
|
||||
dependencies = [
|
||||
"android_system_properties",
|
||||
"core-foundation-sys",
|
||||
"iana-time-zone-haiku",
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
"windows-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone-haiku"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "icu_collections"
|
||||
version = "1.5.0"
|
||||
@ -4511,6 +4563,7 @@ dependencies = [
|
||||
"axum 0.7.9",
|
||||
"axum-tracing-opentelemetry",
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
"clap 4.5.21",
|
||||
"csv",
|
||||
"futures",
|
||||
|
@ -64,6 +64,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
|
||||
csv = "1.3.0"
|
||||
ureq = "=2.9"
|
||||
pyo3 = { workspace = true }
|
||||
chrono = "0.4.39"
|
||||
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -1,5 +1,6 @@
|
||||
use crate::infer::InferError;
|
||||
use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
|
||||
use chrono::Local;
|
||||
use minijinja::{Environment, ErrorKind, Template};
|
||||
use minijinja_contrib::pycompat;
|
||||
|
||||
@ -8,6 +9,11 @@ pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Err
|
||||
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
|
||||
}
|
||||
|
||||
/// Get the current date in a specific format (custom function), similar to `datetime.now().strftime()` in Python
|
||||
pub(crate) fn strftime_now(format_str: String) -> Result<String, minijinja::Error> {
|
||||
Ok(Local::now().format(&format_str).to_string())
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ChatTemplate {
|
||||
template: Template<'static, 'static>,
|
||||
@ -27,6 +33,7 @@ impl ChatTemplate {
|
||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||
let template_str = template.into_boxed_str();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
env.add_function("strftime_now", strftime_now);
|
||||
tracing::debug!("Loading template: {}", template_str);
|
||||
|
||||
// leaking env and template_str as read-only, static resources for performance.
|
||||
@ -109,11 +116,12 @@ impl ChatTemplate {
|
||||
// tests
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::infer::chat_template::raise_exception;
|
||||
use crate::infer::chat_template::{raise_exception, strftime_now};
|
||||
use crate::infer::ChatTemplate;
|
||||
use crate::{
|
||||
ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool,
|
||||
};
|
||||
use chrono::Local;
|
||||
use minijinja::Environment;
|
||||
|
||||
#[test]
|
||||
@ -182,6 +190,7 @@ mod tests {
|
||||
fn test_chat_template_invalid_with_raise() {
|
||||
let mut env = Environment::new();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
env.add_function("strftime_now", strftime_now);
|
||||
|
||||
let source = r#"
|
||||
{{ bos_token }}
|
||||
@ -253,6 +262,7 @@ mod tests {
|
||||
fn test_chat_template_valid_with_raise() {
|
||||
let mut env = Environment::new();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
env.add_function("strftime_now", strftime_now);
|
||||
|
||||
let source = r#"
|
||||
{{ bos_token }}
|
||||
@ -307,10 +317,79 @@ mod tests {
|
||||
assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_valid_with_strftime_now() {
|
||||
let mut env = Environment::new();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
env.add_function("strftime_now", strftime_now);
|
||||
|
||||
let source = r#"
|
||||
{% set today = strftime_now("%Y-%m-%d") %}
|
||||
{% set default_system_message = "The current date is " + today + "." %}
|
||||
{{ bos_token }}
|
||||
{% if messages[0]['role'] == 'system' %}
|
||||
{ set system_message = messages[0]['content'] %}
|
||||
{%- set loop_messages = messages[1:] %}
|
||||
{% else %}
|
||||
{%- set system_message = default_system_message %}
|
||||
{%- set loop_messages = messages %}
|
||||
{% endif %}
|
||||
{{ '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}
|
||||
{% for message in loop_messages %}
|
||||
{% if message['role'] == 'user' %}
|
||||
{{ '[INST]' + message['content'] + '[/INST]' }}
|
||||
{% elif message['role'] == 'assistant' %}
|
||||
{{ message['content'] + eos_token }}
|
||||
{% else %}
|
||||
{{ raise_exception('Only user and assistant roles are supported!') }}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
"#;
|
||||
|
||||
// trim all the whitespace
|
||||
let source = source
|
||||
.lines()
|
||||
.map(|line| line.trim())
|
||||
.collect::<Vec<&str>>()
|
||||
.join("");
|
||||
|
||||
let tmpl = env.template_from_str(&source);
|
||||
|
||||
let chat_template_inputs = ChatTemplateInputs {
|
||||
messages: vec![
|
||||
TextMessage {
|
||||
role: "user".to_string(),
|
||||
content: "Hi!".to_string(),
|
||||
},
|
||||
TextMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: "Hello how can I help?".to_string(),
|
||||
},
|
||||
TextMessage {
|
||||
role: "user".to_string(),
|
||||
content: "What is Deep Learning?".to_string(),
|
||||
},
|
||||
TextMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: "magic!".to_string(),
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
eos_token: Some("[EOS]"),
|
||||
add_generation_prompt: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let current_date = Local::now().format("%Y-%m-%d").to_string();
|
||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||
assert_eq!(result, format!("[BOS][SYSTEM_PROMPT]The current date is {}.[/SYSTEM_PROMPT][INST]Hi![/INST]Hello how can I help?[EOS][INST]What is Deep Learning?[/INST]magic![EOS]", current_date));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_valid_with_add_generation_prompt() {
|
||||
let mut env = Environment::new();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
env.add_function("strftime_now", strftime_now);
|
||||
|
||||
let source = r#"
|
||||
{% for message in messages %}
|
||||
@ -502,6 +581,7 @@ mod tests {
|
||||
{
|
||||
let mut env = Environment::new();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
env.add_function("strftime_now", strftime_now);
|
||||
let tmpl = env.template_from_str(chat_template);
|
||||
let result = tmpl.unwrap().render(input).unwrap();
|
||||
assert_eq!(result, target);
|
||||
@ -776,6 +856,7 @@ mod tests {
|
||||
{
|
||||
let mut env = Environment::new();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
env.add_function("strftime_now", strftime_now);
|
||||
// trim all the whitespace
|
||||
let chat_template = chat_template
|
||||
.lines()
|
||||
|
Loading…
Reference in New Issue
Block a user