Add test_chat_template_loop_controls to test break

This commit is contained in:
Alvaro Bartolome 2025-02-07 13:35:26 +01:00
parent 5691c91350
commit 20603881e3
No known key found for this signature in database

View File

@ -186,6 +186,72 @@ mod tests {
);
}
#[test]
fn test_chat_template_loop_controls() {
// some chat templates as e.g. CohereForAI/c4ai-command-r7b-12-202 contain `break`
// statements in their chat templates, so the feature `loop_controls` has been included
// in `minijinja`
let env = Environment::new();
let source = r#"
{% set user_count = 0 %}
{% for message in messages %}
{% if message['role'] == 'user' %}
{{'### User:\n' + message['content']+'\n\n'}}
{% set user_count = user_count + 1 %}
{% if user_count >= 2 %}
{% break %}
{% endif %}
{% elif message['role'] == 'assistant' %}
{{'### Assistant:\n' + message['content']}}
{% endif %}
{% endfor %}
{% if add_generation_prompt %}
{{ '### Assistant:\n' }}
{% endif %}"#;
// trim all the whitespace
let source = source
.lines()
.map(|line| line.trim())
.collect::<Vec<&str>>()
.join("");
let tmpl = env.template_from_str(&source);
let chat_template_inputs = ChatTemplateInputs {
messages: vec![
TextMessage {
role: "user".to_string(),
content: "Hi!".to_string(),
},
TextMessage {
role: "assistant".to_string(),
content: "Hello how can I help?".to_string(),
},
TextMessage {
role: "user".to_string(),
content: "What is Deep Learning?".to_string(),
},
TextMessage {
role: "assistant".to_string(),
content: "magic!".to_string(),
},
],
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
add_generation_prompt: true,
..Default::default()
};
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
assert_eq!(
result,
"### User:\nHi!\n\n### Assistant:\nHello how can I help?### User:\nWhat is Deep Learning?\n\n### Assistant:\n"
);
}
#[test]
fn test_chat_template_invalid_with_raise() {
let mut env = Environment::new();