From bc817953706b2e20543e5ea790a5f5c8a7586774 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 17 Jan 2024 18:34:25 -0500 Subject: [PATCH] fix: use generic raise_exception function to improve tests --- router/src/infer.rs | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index a61331d5..8a9875eb 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -47,6 +47,11 @@ struct Shared { batching_task: Notify, } +/// Raise a exception (custom function) used in the chat templates +fn raise_exception(err_text: String) -> Result { + Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) +} + impl Infer { #[allow(clippy::too_many_arguments)] pub(crate) fn new( @@ -87,9 +92,6 @@ impl Infer { let template = tokenizer_config.chat_template.map(|t| { let mut env = Box::new(Environment::new()); let template_str = t.into_boxed_str(); - fn raise_exception(err_text: String) -> Result { - Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) - } env.add_function("raise_exception", raise_exception); // leaking env and template_str as read-only, static resources for performance. Box::leak(env) @@ -727,6 +729,7 @@ impl InferError { // tests #[cfg(test)] mod tests { + use crate::infer::raise_exception; use crate::ChatTemplateInputs; use crate::Message; use minijinja::Environment; @@ -802,13 +805,6 @@ magic!"# #[test] fn test_chat_template_invalid_with_raise() { let mut env = Environment::new(); - - fn raise_exception(name: String) -> Result { - Err(minijinja::Error::new( - minijinja::ErrorKind::TemplateNotFound, - format!("Template not found: {}", name), - )) - } env.add_function("raise_exception", raise_exception); let source = r#" @@ -868,8 +864,8 @@ magic!"# Ok(_) => panic!("Should have failed"), Err(e) => { assert_eq!( - e.to_string(), - "template not found: Template not found: Conversation roles must alternate user/assistant/user/assistant/... (in :1)" + e.detail().unwrap(), + "Conversation roles must alternate user/assistant/user/assistant/..." ); } } @@ -878,13 +874,6 @@ magic!"# #[test] fn test_chat_template_valid_with_raise() { let mut env = Environment::new(); - - fn raise_exception(name: String) -> Result { - Err(minijinja::Error::new( - minijinja::ErrorKind::TemplateNotFound, - format!("Template not found: {}", name), - )) - } env.add_function("raise_exception", raise_exception); let source = r#"