mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
fix: use generic raise_exception function to improve tests
This commit is contained in:
parent
f378c60517
commit
bc81795370
@ -47,6 +47,11 @@ struct Shared {
|
|||||||
batching_task: Notify,
|
batching_task: Notify,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Raise a exception (custom function) used in the chat templates
|
||||||
|
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
|
||||||
|
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
|
||||||
|
}
|
||||||
|
|
||||||
impl Infer {
|
impl Infer {
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
@ -87,9 +92,6 @@ impl Infer {
|
|||||||
let template = tokenizer_config.chat_template.map(|t| {
|
let template = tokenizer_config.chat_template.map(|t| {
|
||||||
let mut env = Box::new(Environment::new());
|
let mut env = Box::new(Environment::new());
|
||||||
let template_str = t.into_boxed_str();
|
let template_str = t.into_boxed_str();
|
||||||
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
|
|
||||||
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
|
|
||||||
}
|
|
||||||
env.add_function("raise_exception", raise_exception);
|
env.add_function("raise_exception", raise_exception);
|
||||||
// leaking env and template_str as read-only, static resources for performance.
|
// leaking env and template_str as read-only, static resources for performance.
|
||||||
Box::leak(env)
|
Box::leak(env)
|
||||||
@ -727,6 +729,7 @@ impl InferError {
|
|||||||
// tests
|
// tests
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use crate::infer::raise_exception;
|
||||||
use crate::ChatTemplateInputs;
|
use crate::ChatTemplateInputs;
|
||||||
use crate::Message;
|
use crate::Message;
|
||||||
use minijinja::Environment;
|
use minijinja::Environment;
|
||||||
@ -802,13 +805,6 @@ magic!"#
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_chat_template_invalid_with_raise() {
|
fn test_chat_template_invalid_with_raise() {
|
||||||
let mut env = Environment::new();
|
let mut env = Environment::new();
|
||||||
|
|
||||||
fn raise_exception(name: String) -> Result<String, minijinja::Error> {
|
|
||||||
Err(minijinja::Error::new(
|
|
||||||
minijinja::ErrorKind::TemplateNotFound,
|
|
||||||
format!("Template not found: {}", name),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
env.add_function("raise_exception", raise_exception);
|
env.add_function("raise_exception", raise_exception);
|
||||||
|
|
||||||
let source = r#"
|
let source = r#"
|
||||||
@ -868,8 +864,8 @@ magic!"#
|
|||||||
Ok(_) => panic!("Should have failed"),
|
Ok(_) => panic!("Should have failed"),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
e.to_string(),
|
e.detail().unwrap(),
|
||||||
"template not found: Template not found: Conversation roles must alternate user/assistant/user/assistant/... (in <string>:1)"
|
"Conversation roles must alternate user/assistant/user/assistant/..."
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -878,13 +874,6 @@ magic!"#
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_chat_template_valid_with_raise() {
|
fn test_chat_template_valid_with_raise() {
|
||||||
let mut env = Environment::new();
|
let mut env = Environment::new();
|
||||||
|
|
||||||
fn raise_exception(name: String) -> Result<String, minijinja::Error> {
|
|
||||||
Err(minijinja::Error::new(
|
|
||||||
minijinja::ErrorKind::TemplateNotFound,
|
|
||||||
format!("Template not found: {}", name),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
env.add_function("raise_exception", raise_exception);
|
env.add_function("raise_exception", raise_exception);
|
||||||
|
|
||||||
let source = r#"
|
let source = r#"
|
||||||
|
Loading…
Reference in New Issue
Block a user