diff --git a/router/Cargo.toml b/router/Cargo.toml
index 170debda..a3afcf35 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -44,7 +44,7 @@ utoipa = { version = "3.5.0", features = ["axum_extras"] }
 utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
 ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
-minijinja = "1.0.10"
+minijinja = "1.0.12"
 futures-util = "0.3.30"
 
 [build-dependencies]
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 42405327..df17d907 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -1050,4 +1050,141 @@ mod tests {
         let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
         assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n");
     }
+
+    /// Renders a table of named chat templates against a fixed example conversation
+    /// and checks every rendered prompt against its exact expected string.
+    #[test]
+    fn test_many_chat_templates() {
+        let example_chat = vec![
+            ("user", "Hello, how are you?"),
+            ("assistant", "I'm doing great. How can I help you today?"),
+            ("user", "I'd like to show off how chat templating works!"),
+        ];
+
+        // Same conversation, prefixed with a system prompt.
+        let example_chat_with_system = std::iter::once((
+            "system",
+            "You are a friendly chatbot who always responds in the style of a pirate",
+        ))
+        .chain(example_chat.iter().copied())
+        .collect::<Vec<_>>();
+
+        // (name, chat_template, messages, add_generation_prompt, bos_token, eos_token, target)
+        let test_default_templates = vec![
+            (
+                /* name */ "_base",
+                /* chat_template */ "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
+                /* messages */ example_chat.as_slice(),
+                /* add_generation_prompt */ false,
+                /* bos_token */ "",
+                /* eos_token */ "",
+                /* target */ "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
+            ),
+            (
+                /* name */ "blenderbot",
+                /* chat_template */ "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}",
+                /* messages */ example_chat.as_slice(),
+                /* add_generation_prompt */ false,
+                /* bos_token */ "",
+                /* eos_token */ "",
+                // The separator after the assistant turn plus the user prefix give two spaces before "I'd".
+                /* target */ " Hello, how are you? I'm doing great. How can I help you today?  I'd like to show off how chat templating works!",
+            ),
+            (
+                /* name */ "blenderbot_small",
+                /* chat_template */ "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}",
+                /* messages */ example_chat.as_slice(),
+                /* add_generation_prompt */ false,
+                /* bos_token */ "",
+                /* eos_token */ "",
+                // The separator after the assistant turn plus the user prefix give two spaces before "I'd".
+                /* target */ " Hello, how are you? I'm doing great. How can I help you today?  I'd like to show off how chat templating works!",
+            ),
+            (
+                /* name */ "bloom",
+                /* chat_template */ "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
+                /* messages */ example_chat.as_slice(),
+                /* add_generation_prompt */ false,
+                /* bos_token */ "",
+                /* eos_token */ "",
+                /* target */ "Hello, how are you?I'm doing great. How can I help you today?I'd like to show off how chat templating works!",
+            ),
+            (
+                /* name */ "gpt_neox",
+                /* chat_template */ "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
+                /* messages */ example_chat.as_slice(),
+                /* add_generation_prompt */ false,
+                /* bos_token */ "",
+                /* eos_token */ "<|endoftext|>",
+                /* target */ "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
+            ),
+            (
+                /* name */ "gpt2",
+                /* chat_template */ "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
+                /* messages */ example_chat.as_slice(),
+                /* add_generation_prompt */ false,
+                /* bos_token */ "",
+                /* eos_token */ "<|endoftext|>",
+                /* target */ "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
+            ),
+            (
+                /* name */ "llama",
+                // NOTE: the `.strip()` has been replaced with `| trim` in the following template
+                /* chat_template */ "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content | trim + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}",
+                /* messages */ example_chat_with_system.as_slice(),
+                /* add_generation_prompt */ true,
+                /* bos_token */ "",
+                /* eos_token */ "",
+                /* target */ "[INST] <>\nYou are a friendly chatbot who always responds in the style of a pirate\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]",
+            ),
+            (
+                /* name */ "whisper",
+                /* chat_template */ "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
+                /* messages */ example_chat.as_slice(),
+                /* add_generation_prompt */ true,
+                /* bos_token */ "",
+                /* eos_token */ "<|endoftext|>",
+                /* target */ "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
+            ),
+        ];
+
+        // The environment does not depend on the test case, so build it once.
+        let mut env = Environment::new();
+        env.add_function("raise_exception", raise_exception);
+
+        for (name, chat_template, messages, add_generation_prompt, bos_token, eos_token, target) in
+            test_default_templates
+        {
+            // Trim leading/trailing whitespace from every template line, then join the lines.
+            let chat_template = chat_template
+                .lines()
+                .map(|line| line.trim())
+                .collect::<Vec<_>>()
+                .join("");
+
+            let tmpl = env
+                .template_from_str(&chat_template)
+                .unwrap_or_else(|err| panic!("template `{}` failed to compile: {}", name, err));
+
+            let chat_template_inputs = ChatTemplateInputs {
+                messages: messages
+                    .iter()
+                    .map(|(role, content)| Message {
+                        role: role.to_string(),
+                        content: Some(content.to_string()),
+                        name: None,
+                        tool_calls: None,
+                    })
+                    .collect(),
+                bos_token: Some(bos_token),
+                eos_token: Some(eos_token),
+                add_generation_prompt,
+            };
+
+            let result = tmpl
+                .render(chat_template_inputs)
+                .unwrap_or_else(|err| panic!("template `{}` failed to render: {}", name, err));
+            assert_eq!(result, target, "unexpected output for template `{}`", name);
+        }
+    }
 }