fix tests

This commit is contained in:
OlivierDehaene 2024-04-10 17:20:07 +02:00
parent 07a3050b20
commit 93e7ba54c0

View File

@ -48,13 +48,13 @@ pub struct HubModelInfo {
pub pipeline_tag: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct ChatTemplate {
name: String,
template: String,
}
#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ChatTemplateVersions {
Single(String),
@ -990,7 +990,10 @@ mod tests {
let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
// check that we successfully parsed the tokens
assert_eq!(config.chat_template, Some("test".to_string()));
assert_eq!(
config.chat_template,
Some(ChatTemplateVersions::Single("test".to_string()))
);
assert_eq!(
config.bos_token,
Some("<begin▁of▁sentence>".to_string())
@ -1022,7 +1025,10 @@ mod tests {
let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
// check that we successfully parsed the tokens
assert_eq!(config.chat_template, Some("test".to_string()));
assert_eq!(
config.chat_template,
Some(ChatTemplateVersions::Single("test".to_string()))
);
assert_eq!(
config.bos_token,
Some("<begin▁of▁sentence>".to_string())