Add strftime_now callable function for minijinja chat templates (#2983)

* Add `chrono` and `strftime_now` function callable

* Fix `test_chat_template_valid_with_strftime_now`

* Fix `test_chat_template_valid_with_strftime_now`
This commit is contained in:
Alvaro Bartolome 2025-02-03 15:30:48 +01:00 committed by GitHub
parent e3f2018cb5
commit 88fd56f549
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 136 additions and 1 deletions

53
Cargo.lock generated
View File

@ -52,6 +52,21 @@ version = "0.2.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9"
[[package]]
name = "android-tzdata"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
[[package]]
name = "android_system_properties"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
dependencies = [
"libc",
]
[[package]]
name = "anstream"
version = "0.6.18"
@ -651,6 +666,20 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "chrono"
version = "0.4.39"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"wasm-bindgen",
"windows-targets 0.52.6",
]
[[package]]
name = "clang-sys"
version = "1.8.1"
@ -1801,6 +1830,29 @@ dependencies = [
"tracing",
]
[[package]]
name = "iana-time-zone"
version = "0.1.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"windows-core",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
dependencies = [
"cc",
]
[[package]]
name = "icu_collections"
version = "1.5.0"
@ -4511,6 +4563,7 @@ dependencies = [
"axum 0.7.9",
"axum-tracing-opentelemetry",
"base64 0.22.1",
"chrono",
"clap 4.5.21",
"csv",
"futures",

View File

@ -64,6 +64,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
csv = "1.3.0"
ureq = "=2.9"
pyo3 = { workspace = true }
chrono = "0.4.39"
[build-dependencies]

View File

@ -1,5 +1,6 @@
use crate::infer::InferError;
use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
use chrono::Local;
use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat;
@ -8,6 +9,11 @@ pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Err
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
}
/// Get the current date in a specific format (custom function), similar to `datetime.now().strftime()` in Python
pub(crate) fn strftime_now(format_str: String) -> Result<String, minijinja::Error> {
Ok(Local::now().format(&format_str).to_string())
}
#[derive(Clone)]
pub(crate) struct ChatTemplate {
template: Template<'static, 'static>,
@ -27,6 +33,7 @@ impl ChatTemplate {
env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
tracing::debug!("Loading template: {}", template_str);
// leaking env and template_str as read-only, static resources for performance.
@ -109,11 +116,12 @@ impl ChatTemplate {
// tests
#[cfg(test)]
mod tests {
use crate::infer::chat_template::raise_exception;
use crate::infer::chat_template::{raise_exception, strftime_now};
use crate::infer::ChatTemplate;
use crate::{
ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool,
};
use chrono::Local;
use minijinja::Environment;
#[test]
@ -182,6 +190,7 @@ mod tests {
fn test_chat_template_invalid_with_raise() {
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
let source = r#"
{{ bos_token }}
@ -253,6 +262,7 @@ mod tests {
fn test_chat_template_valid_with_raise() {
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
let source = r#"
{{ bos_token }}
@ -307,10 +317,79 @@ mod tests {
assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]");
}
#[test]
fn test_chat_template_valid_with_strftime_now() {
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
let source = r#"
{% set today = strftime_now("%Y-%m-%d") %}
{% set default_system_message = "The current date is " + today + "." %}
{{ bos_token }}
{% if messages[0]['role'] == 'system' %}
{ set system_message = messages[0]['content'] %}
{%- set loop_messages = messages[1:] %}
{% else %}
{%- set system_message = default_system_message %}
{%- set loop_messages = messages %}
{% endif %}
{{ '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}
{% for message in loop_messages %}
{% if message['role'] == 'user' %}
{{ '[INST]' + message['content'] + '[/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ message['content'] + eos_token }}
{% else %}
{{ raise_exception('Only user and assistant roles are supported!') }}
{% endif %}
{% endfor %}
"#;
// trim all the whitespace
let source = source
.lines()
.map(|line| line.trim())
.collect::<Vec<&str>>()
.join("");
let tmpl = env.template_from_str(&source);
let chat_template_inputs = ChatTemplateInputs {
messages: vec![
TextMessage {
role: "user".to_string(),
content: "Hi!".to_string(),
},
TextMessage {
role: "assistant".to_string(),
content: "Hello how can I help?".to_string(),
},
TextMessage {
role: "user".to_string(),
content: "What is Deep Learning?".to_string(),
},
TextMessage {
role: "assistant".to_string(),
content: "magic!".to_string(),
},
],
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
add_generation_prompt: true,
..Default::default()
};
let current_date = Local::now().format("%Y-%m-%d").to_string();
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
assert_eq!(result, format!("[BOS][SYSTEM_PROMPT]The current date is {}.[/SYSTEM_PROMPT][INST]Hi![/INST]Hello how can I help?[EOS][INST]What is Deep Learning?[/INST]magic![EOS]", current_date));
}
#[test]
fn test_chat_template_valid_with_add_generation_prompt() {
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
let source = r#"
{% for message in messages %}
@ -502,6 +581,7 @@ mod tests {
{
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
let tmpl = env.template_from_str(chat_template);
let result = tmpl.unwrap().render(input).unwrap();
assert_eq!(result, target);
@ -776,6 +856,7 @@ mod tests {
{
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
// trim all the whitespace
let chat_template = chat_template
.lines()