feat: update and add tests for add_generation_prompt

This commit is contained in:
drbh 2024-02-06 11:43:43 -05:00
parent ff0428a351
commit 53b6b8bd08

View File

@ -807,22 +807,14 @@ mod tests {
], ],
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"), eos_token: Some("[EOS]"),
add_generation_prompt: false, add_generation_prompt: true,
}; };
let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
assert_eq!( assert_eq!(
result, result,
r#"### User: "### User:\nHi!\n\n### Assistant:\nHello how can I help?### User:\nWhat is Deep Learning?\n\n### Assistant:\nmagic!### Assistant:\n"
Hi!
### Assistant:
Hello how can I help?### User:
What is Deep Learning?
### Assistant:
magic!"#
); );
} }
@ -880,7 +872,7 @@ magic!"#
], ],
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"), eos_token: Some("[EOS]"),
add_generation_prompt: false, add_generation_prompt: true,
}; };
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap(); let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
@ -946,10 +938,60 @@ magic!"#
], ],
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"), eos_token: Some("[EOS]"),
add_generation_prompt: false, add_generation_prompt: true,
}; };
let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]"); assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]");
} }
#[test]
fn test_chat_template_valid_with_add_generation_prompt() {
    // Environment with the exception hook the chat templates expect.
    let mut env = Environment::new();
    env.add_function("raise_exception", raise_exception);

    // ChatML-style template: one <|im_start|>…<|im_end|> block per message,
    // plus a trailing assistant header when `add_generation_prompt` is set.
    let source = r#"
{% for message in messages %}
{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
{% endfor %}
{% if add_generation_prompt %}
{{ '<|im_start|>assistant\n' }}
{% endif %}"#;

    // Collapse the template onto one line so no incidental whitespace from
    // the source layout leaks into the rendered output.
    let squeezed: String = source.lines().map(str::trim).collect();
    let tmpl = env.template_from_str(&squeezed);

    // Small constructor closure to cut down on Message boilerplate.
    let msg = |role: &str, content: &str| Message {
        role: role.to_string(),
        content: content.to_string(),
    };

    let inputs = ChatTemplateInputs {
        messages: vec![
            msg("user", "Hi!"),
            msg("assistant", "Hello how can I help?"),
            msg("user", "What is Deep Learning?"),
            msg("assistant", "magic!"),
        ],
        bos_token: Some("[BOS]"),
        eos_token: Some("[EOS]"),
        add_generation_prompt: true,
    };

    // With add_generation_prompt=true the output must end with an opened,
    // unanswered assistant turn for the model to continue.
    let rendered = tmpl.unwrap().render(inputs).unwrap();
    assert_eq!(
        rendered,
        "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n"
    );
}
} }