mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: update create_post_processor logic for token type
This commit is contained in:
parent
c326ffdac0
commit
8885688630
@ -90,18 +90,6 @@ impl TokenizerConfigToken {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TokenizerConfigToken> for String {
|
||||
fn from(token: TokenizerConfigToken) -> Self {
|
||||
token.as_str().to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for TokenizerConfigToken {
|
||||
fn from(s: String) -> Self {
|
||||
TokenizerConfigToken::String(s)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "processor_class")]
|
||||
pub enum HubPreprocessorConfig {
|
||||
|
@ -553,11 +553,11 @@ pub fn create_post_processor(
|
||||
if add_bos_token {
|
||||
if let Some(bos) = bos_token {
|
||||
let bos_token_id = tokenizer
|
||||
.token_to_id(bos)
|
||||
.token_to_id(bos.as_str())
|
||||
.expect("Should have found the bos token id");
|
||||
special_tokens.push((bos.clone(), bos_token_id));
|
||||
single.push(format!("{}:0", bos));
|
||||
pair.push(format!("{}:0", bos));
|
||||
special_tokens.push((bos.as_str(), bos_token_id));
|
||||
single.push(format!("{}:0", bos.as_str()));
|
||||
pair.push(format!("{}:0", bos.as_str()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -567,17 +567,17 @@ pub fn create_post_processor(
|
||||
if add_eos_token {
|
||||
if let Some(eos) = eos_token {
|
||||
let eos_token_id = tokenizer
|
||||
.token_to_id(eos)
|
||||
.token_to_id(eos.as_str())
|
||||
.expect("Should have found the eos token id");
|
||||
special_tokens.push((eos.clone(), eos_token_id));
|
||||
single.push(format!("{}:0", eos));
|
||||
pair.push(format!("{}:0", eos));
|
||||
special_tokens.push((eos.as_str(), eos_token_id));
|
||||
single.push(format!("{}:0", eos.as_str()));
|
||||
pair.push(format!("{}:0", eos.as_str()));
|
||||
}
|
||||
}
|
||||
|
||||
if add_bos_token {
|
||||
if let Some(bos) = bos_token {
|
||||
single.push(format!("{}:1", bos));
|
||||
single.push(format!("{}:1", bos.as_str()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -585,7 +585,7 @@ pub fn create_post_processor(
|
||||
|
||||
if add_eos_token {
|
||||
if let Some(eos) = eos_token {
|
||||
pair.push(format!("{}:1", eos));
|
||||
pair.push(format!("{}:1", eos.as_str()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -611,14 +611,15 @@ enum RouterError {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use text_generation_router::TokenizerConfigToken;
|
||||
|
||||
#[test]
|
||||
fn test_create_post_processor() {
|
||||
let tokenizer_config = HubTokenizerConfig {
|
||||
add_bos_token: None,
|
||||
add_eos_token: None,
|
||||
bos_token: Some("<s>".to_string()),
|
||||
eos_token: Some("</s>".to_string()),
|
||||
bos_token: Some(TokenizerConfigToken::String("<s>".to_string())),
|
||||
eos_token: Some(TokenizerConfigToken::String("</s>".to_string())),
|
||||
chat_template: None,
|
||||
tokenizer_class: None,
|
||||
completion_template: None,
|
||||
|
Loading…
Reference in New Issue
Block a user