From 88fd56f549637b36529582b682c4c76eb78e36a5 Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Mon, 3 Feb 2025 15:30:48 +0100 Subject: [PATCH] Add `strftime_now` callable function for `minijinja` chat templates (#2983) * Add `chrono` and `strftime_now` function callable * Fix `test_chat_template_valid_with_strftime_now` * Fix `test_chat_template_valid_with_strftime_now` --- Cargo.lock | 53 ++++++++++++++++++++ router/Cargo.toml | 1 + router/src/infer/chat_template.rs | 83 ++++++++++++++++++++++++++++++- 3 files changed, 136 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index d6883f9d..915de0d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -52,6 +52,21 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anstream" version = "0.6.18" @@ -651,6 +666,20 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "chrono" +version = "0.4.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "wasm-bindgen", + "windows-targets 0.52.6", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -1801,6 +1830,29 @@ dependencies = [ 
"tracing", ] +[[package]] +name = "iana-time-zone" +version = "0.1.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "icu_collections" version = "1.5.0" @@ -4511,6 +4563,7 @@ dependencies = [ "axum 0.7.9", "axum-tracing-opentelemetry", "base64 0.22.1", + "chrono", "clap 4.5.21", "csv", "futures", diff --git a/router/Cargo.toml b/router/Cargo.toml index 2e621dfc..e4d0179a 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -64,6 +64,7 @@ uuid = { version = "1.9.1", default-features = false, features = [ csv = "1.3.0" ureq = "=2.9" pyo3 = { workspace = true } +chrono = "0.4.39" [build-dependencies] diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index 2bda7193..8303ee76 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -1,5 +1,6 @@ use crate::infer::InferError; use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool}; +use chrono::Local; use minijinja::{Environment, ErrorKind, Template}; use minijinja_contrib::pycompat; @@ -8,6 +9,11 @@ pub(crate) fn raise_exception(err_text: String) -> Result Result { + Ok(Local::now().format(&format_str).to_string()) +} + #[derive(Clone)] pub(crate) struct ChatTemplate { template: Template<'static, 'static>, @@ -27,6 +33,7 @@ impl ChatTemplate { env.set_unknown_method_callback(pycompat::unknown_method_callback); let template_str = template.into_boxed_str(); env.add_function("raise_exception", raise_exception); + 
env.add_function("strftime_now", strftime_now); tracing::debug!("Loading template: {}", template_str); // leaking env and template_str as read-only, static resources for performance. @@ -109,11 +116,12 @@ impl ChatTemplate { // tests #[cfg(test)] mod tests { - use crate::infer::chat_template::raise_exception; + use crate::infer::chat_template::{raise_exception, strftime_now}; use crate::infer::ChatTemplate; use crate::{ ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool, }; + use chrono::Local; use minijinja::Environment; #[test] @@ -182,6 +190,7 @@ mod tests { fn test_chat_template_invalid_with_raise() { let mut env = Environment::new(); env.add_function("raise_exception", raise_exception); + env.add_function("strftime_now", strftime_now); let source = r#" {{ bos_token }} @@ -253,6 +262,7 @@ mod tests { fn test_chat_template_valid_with_raise() { let mut env = Environment::new(); env.add_function("raise_exception", raise_exception); + env.add_function("strftime_now", strftime_now); let source = r#" {{ bos_token }} @@ -307,10 +317,79 @@ mod tests { assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]"); } + #[test] + fn test_chat_template_valid_with_strftime_now() { + let mut env = Environment::new(); + env.add_function("raise_exception", raise_exception); + env.add_function("strftime_now", strftime_now); + + let source = r#" + {% set today = strftime_now("%Y-%m-%d") %} + {% set default_system_message = "The current date is " + today + "." 
%} + {{ bos_token }} + {% if messages[0]['role'] == 'system' %} + { set system_message = messages[0]['content'] %} + {%- set loop_messages = messages[1:] %} + {% else %} + {%- set system_message = default_system_message %} + {%- set loop_messages = messages %} + {% endif %} + {{ '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }} + {% for message in loop_messages %} + {% if message['role'] == 'user' %} + {{ '[INST]' + message['content'] + '[/INST]' }} + {% elif message['role'] == 'assistant' %} + {{ message['content'] + eos_token }} + {% else %} + {{ raise_exception('Only user and assistant roles are supported!') }} + {% endif %} + {% endfor %} + "#; + + // trim all the whitespace + let source = source + .lines() + .map(|line| line.trim()) + .collect::<Vec<&str>>() + .join(""); + + let tmpl = env.template_from_str(&source); + + let chat_template_inputs = ChatTemplateInputs { + messages: vec![ + TextMessage { + role: "user".to_string(), + content: "Hi!".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "Hello how can I help?".to_string(), + }, + TextMessage { + role: "user".to_string(), + content: "What is Deep Learning?".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "magic!".to_string(), + }, + ], + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), + add_generation_prompt: true, + ..Default::default() + }; + + let current_date = Local::now().format("%Y-%m-%d").to_string(); + let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); + assert_eq!(result, format!("[BOS][SYSTEM_PROMPT]The current date is {}.[/SYSTEM_PROMPT][INST]Hi![/INST]Hello how can I help?[EOS][INST]What is Deep Learning?[/INST]magic![EOS]", current_date)); + } + #[test] fn test_chat_template_valid_with_add_generation_prompt() { let mut env = Environment::new(); env.add_function("raise_exception", raise_exception); + env.add_function("strftime_now", strftime_now); let source = r#" {% for message in messages %} @@ -502,6 +581,7 @@
mod tests { { let mut env = Environment::new(); env.add_function("raise_exception", raise_exception); + env.add_function("strftime_now", strftime_now); let tmpl = env.template_from_str(chat_template); let result = tmpl.unwrap().render(input).unwrap(); assert_eq!(result, target); @@ -776,6 +856,7 @@ mod tests { { let mut env = Environment::new(); env.add_function("raise_exception", raise_exception); + env.add_function("strftime_now", strftime_now); // trim all the whitespace let chat_template = chat_template .lines()