fix: use generic raise_exception function to improve tests

This commit is contained in:
drbh 2024-01-17 18:34:25 -05:00
parent f378c60517
commit bc81795370

View File

@ -47,6 +47,11 @@ struct Shared {
batching_task: Notify,
}
/// Raise a exception (custom function) used in the chat templates
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
}
impl Infer {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
@ -87,9 +92,6 @@ impl Infer {
let template = tokenizer_config.chat_template.map(|t| {
let mut env = Box::new(Environment::new());
let template_str = t.into_boxed_str();
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
}
env.add_function("raise_exception", raise_exception);
// leaking env and template_str as read-only, static resources for performance.
Box::leak(env)
@ -727,6 +729,7 @@ impl InferError {
// tests
#[cfg(test)]
mod tests {
use crate::infer::raise_exception;
use crate::ChatTemplateInputs;
use crate::Message;
use minijinja::Environment;
@ -802,13 +805,6 @@ magic!"#
#[test]
fn test_chat_template_invalid_with_raise() {
let mut env = Environment::new();
fn raise_exception(name: String) -> Result<String, minijinja::Error> {
Err(minijinja::Error::new(
minijinja::ErrorKind::TemplateNotFound,
format!("Template not found: {}", name),
))
}
env.add_function("raise_exception", raise_exception);
let source = r#"
@ -868,8 +864,8 @@ magic!"#
Ok(_) => panic!("Should have failed"),
Err(e) => {
assert_eq!(
e.to_string(),
"template not found: Template not found: Conversation roles must alternate user/assistant/user/assistant/... (in <string>:1)"
e.detail().unwrap(),
"Conversation roles must alternate user/assistant/user/assistant/..."
);
}
}
@ -878,13 +874,6 @@ magic!"#
#[test]
fn test_chat_template_valid_with_raise() {
let mut env = Environment::new();
fn raise_exception(name: String) -> Result<String, minijinja::Error> {
Err(minijinja::Error::new(
minijinja::ErrorKind::TemplateNotFound,
format!("Template not found: {}", name),
))
}
env.add_function("raise_exception", raise_exception);
let source = r#"