diff --git a/router/src/infer.rs b/router/src/infer.rs index 637d208d..4da0da0a 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -807,22 +807,14 @@ mod tests { ], bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), - add_generation_prompt: false, + add_generation_prompt: true, }; let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); assert_eq!( result, - r#"### User: -Hi! - -### Assistant: -Hello how can I help?### User: -What is Deep Learning? - -### Assistant: -magic!"# + "### User:\nHi!\n\n### Assistant:\nHello how can I help?### User:\nWhat is Deep Learning?\n\n### Assistant:\nmagic!### Assistant:\n" ); } @@ -880,7 +872,7 @@ magic!"# ], bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), - add_generation_prompt: false, + add_generation_prompt: true, }; let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap(); @@ -946,10 +938,60 @@ magic!"# ], bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), - add_generation_prompt: false, + add_generation_prompt: true, }; let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]"); } + + #[test] + fn test_chat_template_valid_with_add_generation_prompt() { + let mut env = Environment::new(); + env.add_function("raise_exception", raise_exception); + + let source = r#" + {% for message in messages %} + {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}} + {% endfor %} + {% if add_generation_prompt %} + {{ '<|im_start|>assistant\n' }} + {% endif %}"#; + + // trim all the whitespace + let source = source + .lines() + .map(|line| line.trim()) + .collect::>() + .join(""); + + let tmpl = env.template_from_str(&source); + + let chat_template_inputs = ChatTemplateInputs { + messages: vec![ + Message { + role: "user".to_string(), + content: "Hi!".to_string(), + }, + Message { + role: "assistant".to_string(), + content: "Hello how can I help?".to_string(), + }, + Message { + role: "user".to_string(), + content: "What is Deep Learning?".to_string(), + }, + Message { + role: "assistant".to_string(), + content: "magic!".to_string(), + }, + ], + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), + add_generation_prompt: true, + }; + + let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); + assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n"); + } }