diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index 27b65f43..ef4beee2 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -93,7 +93,8 @@ impl ChatTemplate { #[cfg(test)] mod tests { use crate::infer::chat_template::raise_exception; - use crate::{ChatTemplateInputs, TextMessage}; + use crate::infer::ChatTemplate; + use crate::{ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken}; use minijinja::Environment; #[test] @@ -784,58 +785,41 @@ mod tests { #[test] fn test_chat_template_invalid_with_guideline() { - let mut env = Environment::new(); - env.add_function("raise_exception", raise_exception); + let ct = ChatTemplate::new( + "{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n".to_string(), + Some(TokenizerConfigToken::String("".to_string())), + Some(TokenizerConfigToken::String("".to_string())), + ); - let source = "{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n"; + // convert TextMessage to Message + let msgs: Vec = vec![ + Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText( + "I'd like to show off how chat templating works!".to_string(), + ), + }, + Message { + name: None, + role: "assistant".to_string(), + content: MessageContent::SingleText( + "I'm doing great. How can I help you today?".to_string(), + ), + }, + Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText("Hello, how are you?".to_string()), + }, + ]; - // trim all the whitespace - let source = source - .lines() - .map(|line| line.trim()) - .collect::>() - .join(""); - - let tmpl = env.template_from_str(&source); - - let chat_template_inputs = ChatTemplateInputs { - messages: vec![ - TextMessage { - role: "user".to_string(), - content: "Hi!".to_string(), - }, - TextMessage { - role: "user".to_string(), - content: "Hi again!".to_string(), - }, - TextMessage { - role: "assistant".to_string(), - content: "Hello how can I help?".to_string(), - }, - TextMessage { - role: "user".to_string(), - content: "What is Deep Learning?".to_string(), - }, - TextMessage { - role: "assistant".to_string(), - content: "magic!".to_string(), - }, - ], - bos_token: Some("[BOS]"), - eos_token: Some("[EOS]"), - add_generation_prompt: true, - ..Default::default() - }; - - let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap(); + let result = ct.apply(None, msgs, None); match result { Ok(_) => panic!("Should have failed since no guideline is provided"), Err(e) => { - assert_eq!( - e.detail().unwrap(), - "Conversation roles must alternate user/assistant/user/assistant/..." - ); + assert_eq!(e.to_string(), "Missing template vatiable: guideline") } } } diff --git a/router/src/validation.rs b/router/src/validation.rs index 5011158a..0024723c 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -830,7 +830,7 @@ mod tests { .await { // Err(ValidationError::MaxNewTokens(1, 10)) => (), - Ok((_s, 0, 10)) => (), + Ok((_s, _, 0, 10)) => (), r => panic!("Unexpected not max new tokens: {r:?}"), } }