mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 13:52:07 +00:00
Add loop_controls
feature to minijinja
to handle {% break %}
(#2998)
* Add `loop_controls` feature to `minijinja` * Add `test_chat_template_loop_controls` to test `break`
This commit is contained in:
parent
794ec58b75
commit
8a1cfd6122
@ -48,7 +48,7 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
||||
"opentelemetry-otlp",
|
||||
] }
|
||||
minijinja = { workspace = true }
|
||||
minijinja = { workspace = true, features = ["loop_controls"] }
|
||||
minijinja-contrib = { workspace = true }
|
||||
futures-util = "0.3.30"
|
||||
regex = "1.10.3"
|
||||
|
@ -186,6 +186,72 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_loop_controls() {
|
||||
// some chat templates as e.g. CohereForAI/c4ai-command-r7b-12-202 contain `break`
|
||||
// statements in their chat templates, so the feature `loop_controls` has been included
|
||||
// in `minijinja`
|
||||
let env = Environment::new();
|
||||
|
||||
let source = r#"
|
||||
{% set user_count = 0 %}
|
||||
{% for message in messages %}
|
||||
{% if message['role'] == 'user' %}
|
||||
{{'### User:\n' + message['content']+'\n\n'}}
|
||||
{% set user_count = user_count + 1 %}
|
||||
{% if user_count >= 2 %}
|
||||
{% break %}
|
||||
{% endif %}
|
||||
{% elif message['role'] == 'assistant' %}
|
||||
{{'### Assistant:\n' + message['content']}}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% if add_generation_prompt %}
|
||||
{{ '### Assistant:\n' }}
|
||||
{% endif %}"#;
|
||||
|
||||
// trim all the whitespace
|
||||
let source = source
|
||||
.lines()
|
||||
.map(|line| line.trim())
|
||||
.collect::<Vec<&str>>()
|
||||
.join("");
|
||||
|
||||
let tmpl = env.template_from_str(&source);
|
||||
|
||||
let chat_template_inputs = ChatTemplateInputs {
|
||||
messages: vec![
|
||||
TextMessage {
|
||||
role: "user".to_string(),
|
||||
content: "Hi!".to_string(),
|
||||
},
|
||||
TextMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: "Hello how can I help?".to_string(),
|
||||
},
|
||||
TextMessage {
|
||||
role: "user".to_string(),
|
||||
content: "What is Deep Learning?".to_string(),
|
||||
},
|
||||
TextMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: "magic!".to_string(),
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
eos_token: Some("[EOS]"),
|
||||
add_generation_prompt: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
result,
|
||||
"### User:\nHi!\n\n### Assistant:\nHello how can I help?### User:\nWhat is Deep Learning?\n\n### Assistant:\n"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_invalid_with_raise() {
|
||||
let mut env = Environment::new();
|
||||
|
Loading…
Reference in New Issue
Block a user