diff --git a/router/src/lib.rs b/router/src/lib.rs index ea0179ea..a4ba8f28 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -48,13 +48,13 @@ pub struct HubModelInfo { pub pipeline_tag: Option, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, PartialEq)] pub struct ChatTemplate { name: String, template: String, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, PartialEq)] #[serde(untagged)] pub enum ChatTemplateVersions { Single(String), @@ -990,7 +990,10 @@ mod tests { let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap(); // check that we successfully parsed the tokens - assert_eq!(config.chat_template, Some("test".to_string())); + assert_eq!( + config.chat_template, + Some(ChatTemplateVersions::Single("test".to_string())) + ); assert_eq!( config.bos_token, Some("<|begin▁of▁sentence|>".to_string()) @@ -1022,7 +1025,10 @@ mod tests { let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap(); // check that we successfully parsed the tokens - assert_eq!(config.chat_template, Some("test".to_string())); + assert_eq!( + config.chat_template, + Some(ChatTemplateVersions::Single("test".to_string())) + ); assert_eq!( config.bos_token, Some("<|begin▁of▁sentence|>".to_string())