mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: bump minijinja and add test for core templates
This commit is contained in:
parent
0d72af5ab0
commit
a50447dc72
@ -44,7 +44,7 @@ utoipa = { version = "3.5.0", features = ["axum_extras"] }
|
|||||||
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
|
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
|
||||||
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
||||||
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
||||||
minijinja = "1.0.10"
|
minijinja = "1.0.12"
|
||||||
futures-util = "0.3.30"
|
futures-util = "0.3.30"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
@ -1050,4 +1050,131 @@ mod tests {
|
|||||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||||
assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n");
|
assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Renders a table of default chat templates and compares each rendered string with
/// the expected output for that template.
///
/// Every case is a tuple of
/// `(name, chat_template, messages, add_generation_prompt, bos_token, eos_token, target)`.
/// The case name is included in every failure message so that a broken template can be
/// identified without re-running the cases one by one.
#[test]
fn test_many_chat_templates() {
    let example_chat = vec![
        ("user", "Hello, how are you?"),
        ("assistant", "I'm doing great. How can I help you today?"),
        ("user", "I'd like to show off how chat templating works!"),
    ];

    // Same conversation, prefixed with a system message.
    let example_chat_with_system = vec![(
        "system",
        "You are a friendly chatbot who always responds in the style of a pirate",
    )]
    .iter()
    .chain(&example_chat)
    .cloned()
    .collect::<Vec<_>>();

    let test_default_templates = vec![
        (
            /* name */ "_base",
            /* chat_template */ "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
            /* messages */ example_chat.clone(),
            /* add_generation_prompt */ false,
            /* bos_token */ "",
            /* eos_token */ "",
            /* target */ "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
        ),
        (
            /* name */ "blenderbot",
            /* chat_template */ "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}",
            /* messages */ example_chat.clone(),
            /* add_generation_prompt */ false,
            /* bos_token */ "",
            /* eos_token */ "</s>",
            // The template writes one space after every non-last message and one more
            // before every user message, hence the double space before the final turn.
            /* target */ " Hello, how are you? I'm doing great. How can I help you today?  I'd like to show off how chat templating works!</s>",
        ),
        (
            /* name */ "blenderbot_small",
            /* chat_template */ "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}",
            /* messages */ example_chat.clone(),
            /* add_generation_prompt */ false,
            /* bos_token */ "",
            /* eos_token */ "</s>",
            // Same spacing rule as `blenderbot`: double space before the final user turn.
            /* target */ " Hello, how are you? I'm doing great. How can I help you today?  I'd like to show off how chat templating works!</s>",
        ),
        (
            /* name */ "bloom",
            /* chat_template */ "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
            /* messages */ example_chat.clone(),
            /* add_generation_prompt */ false,
            /* bos_token */ "",
            /* eos_token */ "</s>",
            /* target */ "Hello, how are you?</s>I'm doing great. How can I help you today?</s>I'd like to show off how chat templating works!</s>",
        ),
        (
            /* name */ "gpt_neox",
            /* chat_template */ "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
            /* messages */ example_chat.clone(),
            /* add_generation_prompt */ false,
            /* bos_token */ "",
            /* eos_token */ "<|endoftext|>",
            /* target */ "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
        ),
        (
            /* name */ "gpt2",
            /* chat_template */ "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
            /* messages */ example_chat.clone(),
            /* add_generation_prompt */ false,
            /* bos_token */ "",
            /* eos_token */ "<|endoftext|>",
            /* target */ "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
        ),
        (
            /* name */ "llama",
            // NOTE: the `.strip()` has been replaced with `| trim` in the following template
            /* chat_template */ "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content | trim + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}",
            /* messages */ example_chat_with_system.clone(),
            /* add_generation_prompt */ true,
            /* bos_token */ "<s>",
            /* eos_token */ "</s>",
            /* target */ "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]",
        ),
        (
            /* name */ "whisper",
            /* chat_template */ "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
            /* messages */ example_chat.clone(),
            /* add_generation_prompt */ true,
            /* bos_token */ "",
            /* eos_token */ "<|endoftext|>",
            /* target */ "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
        ),
    ];

    for (name, chat_template, messages, add_generation_prompt, bos_token, eos_token, target) in
        test_default_templates
    {
        let mut env = Environment::new();
        // The llama template calls `raise_exception`, so it has to be registered.
        env.add_function("raise_exception", raise_exception);

        // trim all the whitespace
        let chat_template = chat_template
            .lines()
            .map(|line| line.trim())
            .collect::<Vec<&str>>()
            .join("");

        let tmpl = env
            .template_from_str(&chat_template)
            .unwrap_or_else(|err| panic!("template `{}` failed to parse: {:?}", name, err));

        let chat_template_inputs = ChatTemplateInputs {
            messages: messages
                .iter()
                .map(|(role, content)| Message {
                    role: role.to_string(),
                    content: Some(content.to_string()),
                    name: None,
                    tool_calls: None,
                })
                .collect(),
            bos_token: Some(bos_token),
            eos_token: Some(eos_token),
            add_generation_prompt,
        };

        let result = tmpl
            .render(chat_template_inputs)
            .unwrap_or_else(|err| panic!("template `{}` failed to render: {:?}", name, err));
        assert_eq!(
            result, target,
            "unexpected rendering for template `{}`",
            name
        );
    }
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user