fix: update create_post_processor logic for token type

This commit is contained in:
drbh 2024-06-28 15:07:50 +00:00
parent c326ffdac0
commit 8885688630
2 changed files with 13 additions and 24 deletions

View File

@ -90,18 +90,6 @@ impl TokenizerConfigToken {
} }
} }
impl From<TokenizerConfigToken> for String {
fn from(token: TokenizerConfigToken) -> Self {
token.as_str().to_string()
}
}
impl From<String> for TokenizerConfigToken {
fn from(s: String) -> Self {
TokenizerConfigToken::String(s)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "processor_class")] #[serde(tag = "processor_class")]
pub enum HubPreprocessorConfig { pub enum HubPreprocessorConfig {

View File

@ -553,11 +553,11 @@ pub fn create_post_processor(
if add_bos_token { if add_bos_token {
if let Some(bos) = bos_token { if let Some(bos) = bos_token {
let bos_token_id = tokenizer let bos_token_id = tokenizer
.token_to_id(bos) .token_to_id(bos.as_str())
.expect("Should have found the bos token id"); .expect("Should have found the bos token id");
special_tokens.push((bos.clone(), bos_token_id)); special_tokens.push((bos.as_str(), bos_token_id));
single.push(format!("{}:0", bos)); single.push(format!("{}:0", bos.as_str()));
pair.push(format!("{}:0", bos)); pair.push(format!("{}:0", bos.as_str()));
} }
} }
@ -567,17 +567,17 @@ pub fn create_post_processor(
if add_eos_token { if add_eos_token {
if let Some(eos) = eos_token { if let Some(eos) = eos_token {
let eos_token_id = tokenizer let eos_token_id = tokenizer
.token_to_id(eos) .token_to_id(eos.as_str())
.expect("Should have found the eos token id"); .expect("Should have found the eos token id");
special_tokens.push((eos.clone(), eos_token_id)); special_tokens.push((eos.as_str(), eos_token_id));
single.push(format!("{}:0", eos)); single.push(format!("{}:0", eos.as_str()));
pair.push(format!("{}:0", eos)); pair.push(format!("{}:0", eos.as_str()));
} }
} }
if add_bos_token { if add_bos_token {
if let Some(bos) = bos_token { if let Some(bos) = bos_token {
single.push(format!("{}:1", bos)); single.push(format!("{}:1", bos.as_str()));
} }
} }
@ -585,7 +585,7 @@ pub fn create_post_processor(
if add_eos_token { if add_eos_token {
if let Some(eos) = eos_token { if let Some(eos) = eos_token {
pair.push(format!("{}:1", eos)); pair.push(format!("{}:1", eos.as_str()));
} }
} }
@ -611,14 +611,15 @@ enum RouterError {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use text_generation_router::TokenizerConfigToken;
#[test] #[test]
fn test_create_post_processor() { fn test_create_post_processor() {
let tokenizer_config = HubTokenizerConfig { let tokenizer_config = HubTokenizerConfig {
add_bos_token: None, add_bos_token: None,
add_eos_token: None, add_eos_token: None,
bos_token: Some("<s>".to_string()), bos_token: Some(TokenizerConfigToken::String("<s>".to_string())),
eos_token: Some("</s>".to_string()), eos_token: Some(TokenizerConfigToken::String("</s>".to_string())),
chat_template: None, chat_template: None,
tokenizer_class: None, tokenizer_class: None,
completion_template: None, completion_template: None,