fix: improve missing template var test

This commit is contained in:
drbh 2024-08-12 14:35:42 +00:00
parent 2551456fff
commit 298efa41c5
2 changed files with 32 additions and 48 deletions

View File

@ -93,7 +93,8 @@ impl ChatTemplate {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::infer::chat_template::raise_exception; use crate::infer::chat_template::raise_exception;
use crate::{ChatTemplateInputs, TextMessage}; use crate::infer::ChatTemplate;
use crate::{ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken};
use minijinja::Environment; use minijinja::Environment;
#[test] #[test]
@ -784,58 +785,41 @@ mod tests {
#[test] #[test]
fn test_chat_template_invalid_with_guideline() { fn test_chat_template_invalid_with_guideline() {
let mut env = Environment::new(); let ct = ChatTemplate::new(
env.add_function("raise_exception", raise_exception); "{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n<end_of_turn>\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n<end_of_turn>\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n<end_of_turn>\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n".to_string(),
Some(TokenizerConfigToken::String("<s>".to_string())),
Some(TokenizerConfigToken::String("</s>".to_string())),
);
let source = "{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n<end_of_turn>\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n<end_of_turn>\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n<end_of_turn>\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n"; // convert TextMessage to Message
let msgs: Vec<Message> = vec![
// trim all the whitespace Message {
let source = source name: None,
.lines()
.map(|line| line.trim())
.collect::<Vec<&str>>()
.join("");
let tmpl = env.template_from_str(&source);
let chat_template_inputs = ChatTemplateInputs {
messages: vec![
TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: "Hi!".to_string(), content: MessageContent::SingleText(
"I'd like to show off how chat templating works!".to_string(),
),
}, },
TextMessage { Message {
role: "user".to_string(), name: None,
content: "Hi again!".to_string(),
},
TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: "Hello how can I help?".to_string(), content: MessageContent::SingleText(
"I'm doing great. How can I help you today?".to_string(),
),
}, },
TextMessage { Message {
name: None,
role: "user".to_string(), role: "user".to_string(),
content: "What is Deep Learning?".to_string(), content: MessageContent::SingleText("Hello, how are you?".to_string()),
}, },
TextMessage { ];
role: "assistant".to_string(),
content: "magic!".to_string(),
},
],
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
add_generation_prompt: true,
..Default::default()
};
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap(); let result = ct.apply(None, msgs, None);
match result { match result {
Ok(_) => panic!("Should have failed since no guideline is provided"), Ok(_) => panic!("Should have failed since no guideline is provided"),
Err(e) => { Err(e) => {
assert_eq!( assert_eq!(e.to_string(), "Missing template vatiable: guideline")
e.detail().unwrap(),
"Conversation roles must alternate user/assistant/user/assistant/..."
);
} }
} }
} }

View File

@ -830,7 +830,7 @@ mod tests {
.await .await
{ {
// Err(ValidationError::MaxNewTokens(1, 10)) => (), // Err(ValidationError::MaxNewTokens(1, 10)) => (),
Ok((_s, 0, 10)) => (), Ok((_s, _, 0, 10)) => (),
r => panic!("Unexpected not max new tokens: {r:?}"), r => panic!("Unexpected not max new tokens: {r:?}"),
} }
} }