mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Add test_chat_template_loop_controls
to test break
This commit is contained in:
parent
5691c91350
commit
20603881e3
@ -186,6 +186,72 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_loop_controls() {
|
||||
// some chat templates as e.g. CohereForAI/c4ai-command-r7b-12-202 contain `break`
|
||||
// statements in their chat templates, so the feature `loop_controls` has been included
|
||||
// in `minijinja`
|
||||
let env = Environment::new();
|
||||
|
||||
let source = r#"
|
||||
{% set user_count = 0 %}
|
||||
{% for message in messages %}
|
||||
{% if message['role'] == 'user' %}
|
||||
{{'### User:\n' + message['content']+'\n\n'}}
|
||||
{% set user_count = user_count + 1 %}
|
||||
{% if user_count >= 2 %}
|
||||
{% break %}
|
||||
{% endif %}
|
||||
{% elif message['role'] == 'assistant' %}
|
||||
{{'### Assistant:\n' + message['content']}}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% if add_generation_prompt %}
|
||||
{{ '### Assistant:\n' }}
|
||||
{% endif %}"#;
|
||||
|
||||
// trim all the whitespace
|
||||
let source = source
|
||||
.lines()
|
||||
.map(|line| line.trim())
|
||||
.collect::<Vec<&str>>()
|
||||
.join("");
|
||||
|
||||
let tmpl = env.template_from_str(&source);
|
||||
|
||||
let chat_template_inputs = ChatTemplateInputs {
|
||||
messages: vec![
|
||||
TextMessage {
|
||||
role: "user".to_string(),
|
||||
content: "Hi!".to_string(),
|
||||
},
|
||||
TextMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: "Hello how can I help?".to_string(),
|
||||
},
|
||||
TextMessage {
|
||||
role: "user".to_string(),
|
||||
content: "What is Deep Learning?".to_string(),
|
||||
},
|
||||
TextMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: "magic!".to_string(),
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
eos_token: Some("[EOS]"),
|
||||
add_generation_prompt: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
result,
|
||||
"### User:\nHi!\n\n### Assistant:\nHello how can I help?### User:\nWhat is Deep Learning?\n\n### Assistant:\n"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_invalid_with_raise() {
|
||||
let mut env = Environment::new();
|
||||
|
Loading…
Reference in New Issue
Block a user