diff --git a/router/src/lib.rs b/router/src/lib.rs index d5e551a0..9ecfa051 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -90,18 +90,6 @@ impl TokenizerConfigToken { } } -impl From for String { - fn from(token: TokenizerConfigToken) -> Self { - token.as_str().to_string() - } -} - -impl From for TokenizerConfigToken { - fn from(s: String) -> Self { - TokenizerConfigToken::String(s) - } -} - #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "processor_class")] pub enum HubPreprocessorConfig { diff --git a/router/src/main.rs b/router/src/main.rs index 1e8093d8..9a281556 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -553,11 +553,11 @@ pub fn create_post_processor( if add_bos_token { if let Some(bos) = bos_token { let bos_token_id = tokenizer - .token_to_id(bos) + .token_to_id(bos.as_str()) .expect("Should have found the bos token id"); - special_tokens.push((bos.clone(), bos_token_id)); - single.push(format!("{}:0", bos)); - pair.push(format!("{}:0", bos)); + special_tokens.push((bos.as_str(), bos_token_id)); + single.push(format!("{}:0", bos.as_str())); + pair.push(format!("{}:0", bos.as_str())); } } @@ -567,17 +567,17 @@ pub fn create_post_processor( if add_eos_token { if let Some(eos) = eos_token { let eos_token_id = tokenizer - .token_to_id(eos) + .token_to_id(eos.as_str()) .expect("Should have found the eos token id"); - special_tokens.push((eos.clone(), eos_token_id)); - single.push(format!("{}:0", eos)); - pair.push(format!("{}:0", eos)); + special_tokens.push((eos.as_str(), eos_token_id)); + single.push(format!("{}:0", eos.as_str())); + pair.push(format!("{}:0", eos.as_str())); } } if add_bos_token { if let Some(bos) = bos_token { - single.push(format!("{}:1", bos)); + single.push(format!("{}:1", bos.as_str())); } } @@ -585,7 +585,7 @@ pub fn create_post_processor( if add_eos_token { if let Some(eos) = eos_token { - pair.push(format!("{}:1", eos)); + pair.push(format!("{}:1", eos.as_str())); } } @@ -611,14 +611,15 @@ enum RouterError { #[cfg(test)] mod tests { use super::*; + use text_generation_router::TokenizerConfigToken; #[test] fn test_create_post_processor() { let tokenizer_config = HubTokenizerConfig { add_bos_token: None, add_eos_token: None, - bos_token: Some("".to_string()), - eos_token: Some("".to_string()), + bos_token: Some(TokenizerConfigToken::String("".to_string())), + eos_token: Some(TokenizerConfigToken::String("".to_string())), chat_template: None, tokenizer_class: None, completion_template: None,