fix: improve missing template var test

This commit is contained in:
drbh 2024-08-12 14:35:42 +00:00
parent 2551456fff
commit 298efa41c5
2 changed files with 32 additions and 48 deletions

View File

@ -93,7 +93,8 @@ impl ChatTemplate {
#[cfg(test)]
mod tests {
use crate::infer::chat_template::raise_exception;
use crate::{ChatTemplateInputs, TextMessage};
use crate::infer::ChatTemplate;
use crate::{ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken};
use minijinja::Environment;
#[test]
@ -784,58 +785,41 @@ mod tests {
#[test]
fn test_chat_template_invalid_with_guideline() {
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
let ct = ChatTemplate::new(
"{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n<end_of_turn>\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n<end_of_turn>\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n<end_of_turn>\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n".to_string(),
Some(TokenizerConfigToken::String("<s>".to_string())),
Some(TokenizerConfigToken::String("</s>".to_string())),
);
let source = "{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n<end_of_turn>\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n<end_of_turn>\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n<end_of_turn>\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n";
// trim all the whitespace
let source = source
.lines()
.map(|line| line.trim())
.collect::<Vec<&str>>()
.join("");
let tmpl = env.template_from_str(&source);
let chat_template_inputs = ChatTemplateInputs {
messages: vec![
TextMessage {
// convert TextMessage to Message
let msgs: Vec<Message> = vec![
Message {
name: None,
role: "user".to_string(),
content: "Hi!".to_string(),
content: MessageContent::SingleText(
"I'd like to show off how chat templating works!".to_string(),
),
},
TextMessage {
role: "user".to_string(),
content: "Hi again!".to_string(),
},
TextMessage {
Message {
name: None,
role: "assistant".to_string(),
content: "Hello how can I help?".to_string(),
content: MessageContent::SingleText(
"I'm doing great. How can I help you today?".to_string(),
),
},
TextMessage {
Message {
name: None,
role: "user".to_string(),
content: "What is Deep Learning?".to_string(),
content: MessageContent::SingleText("Hello, how are you?".to_string()),
},
TextMessage {
role: "assistant".to_string(),
content: "magic!".to_string(),
},
],
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
add_generation_prompt: true,
..Default::default()
};
];
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
let result = ct.apply(None, msgs, None);
match result {
Ok(_) => panic!("Should have failed since no guideline is provided"),
Err(e) => {
assert_eq!(
e.detail().unwrap(),
"Conversation roles must alternate user/assistant/user/assistant/..."
);
assert_eq!(e.to_string(), "Missing template vatiable: guideline")
}
}
}

View File

@ -830,7 +830,7 @@ mod tests {
.await
{
// Err(ValidationError::MaxNewTokens(1, 10)) => (),
Ok((_s, 0, 10)) => (),
Ok((_s, _, 0, 10)) => (),
r => panic!("Unexpected not max new tokens: {r:?}"),
}
}