feat: bump minijina and add test for core templates

This commit is contained in:
drbh 2024-03-06 15:21:41 +00:00
parent 0d72af5ab0
commit a50447dc72
2 changed files with 128 additions and 1 deletions

View File

@ -44,7 +44,7 @@ utoipa = { version = "3.5.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] } utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true } ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
minijinja = "1.0.10" minijinja = "1.0.12"
futures-util = "0.3.30" futures-util = "0.3.30"
[build-dependencies] [build-dependencies]

View File

@ -1050,4 +1050,131 @@ mod tests {
let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n"); assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n");
} }
#[test]
fn test_many_chat_templates() {
let example_chat = vec![
("user", "Hello, how are you?"),
("assistant", "I'm doing great. How can I help you today?"),
("user", "I'd like to show off how chat templating works!"),
];
let example_chat_with_system = vec![(
"system",
"You are a friendly chatbot who always responds in the style of a pirate",
)]
.iter()
.chain(&example_chat)
.cloned()
.collect::<Vec<_>>();
let test_default_templates = vec![(
/* name */ "_base",
/* chat_template */ "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
/* messages */ example_chat.clone(),
/* add_generation_prompt */ false,
/* bos_token */ "",
/* eos_token */ "",
/* target */ "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
),
(
/* name */ "blenderbot",
/* chat_template */ "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}",
/* messages */ example_chat.clone(),
/* add_generation_prompt */ false,
/* bos_token */ "",
/* eos_token */ "</s>",
/* target */ " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!</s>",
),
(
/* name */ "blenderbot_small",
/* chat_template */ "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}",
/* messages */ example_chat.clone(),
/* add_generation_prompt */ false,
/* bos_token */ "",
/* eos_token */ "</s>",
/* target */ " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!</s>",
),
(
/* name */ "bloom",
/* chat_template */ "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
/* messages */ example_chat.clone(),
/* add_generation_prompt */ false,
/* bos_token */ "",
/* eos_token */ "</s>",
/* target */ "Hello, how are you?</s>I'm doing great. How can I help you today?</s>I'd like to show off how chat templating works!</s>",
),
(
/* name */ "gpt_neox",
/* chat_template */ "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
/* messages */ example_chat.clone(),
/* add_generation_prompt */ false,
/* bos_token */ "",
/* eos_token */ "<|endoftext|>",
/* target */ "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
),
(
/* name */ "gpt2",
/* chat_template */ "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
/* messages */ example_chat.clone(),
/* add_generation_prompt */ false,
/* bos_token */ "",
/* eos_token */ "<|endoftext|>",
/* target */ "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>"
),
(
/* name */ "llama",
// NOTE: the `.strip()` has been replaced with `| trim` in the following template
/* chat_template */ "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content | trim + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}",
/* messages */ example_chat_with_system.clone(),
/* add_generation_prompt */ true,
/* bos_token */ "<s>",
/* eos_token */ "</s>",
/* target */ "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]"
),
(
/* name */ "whisper",
/* chat_template */ "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
/* messages */ example_chat.clone(),
/* add_generation_prompt */ true,
/* bos_token */ "",
/* eos_token */ "<|endoftext|>",
/* target */ "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>"
)
];
for (_name, chat_template, messages, add_generation_prompt, bos_token, eos_token, target) in
test_default_templates
{
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
// trim all the whitespace
let chat_template = chat_template
.lines()
.map(|line| line.trim())
.collect::<Vec<&str>>()
.join("");
let tmpl = env.template_from_str(&chat_template);
let chat_template_inputs = ChatTemplateInputs {
messages: messages
.iter()
.map(|(role, content)| Message {
role: role.to_string(),
content: Some(content.to_string()),
name: None,
tool_calls: None,
})
.collect(),
bos_token: Some(bos_token),
eos_token: Some(eos_token),
add_generation_prompt,
};
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
assert_eq!(result, target);
}
}
} }