mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix: update create_post_processor logic for token type
This commit is contained in:
parent
c326ffdac0
commit
8885688630
@ -90,18 +90,6 @@ impl TokenizerConfigToken {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<TokenizerConfigToken> for String {
|
|
||||||
fn from(token: TokenizerConfigToken) -> Self {
|
|
||||||
token.as_str().to_string()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<String> for TokenizerConfigToken {
|
|
||||||
fn from(s: String) -> Self {
|
|
||||||
TokenizerConfigToken::String(s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(tag = "processor_class")]
|
#[serde(tag = "processor_class")]
|
||||||
pub enum HubPreprocessorConfig {
|
pub enum HubPreprocessorConfig {
|
||||||
|
@ -553,11 +553,11 @@ pub fn create_post_processor(
|
|||||||
if add_bos_token {
|
if add_bos_token {
|
||||||
if let Some(bos) = bos_token {
|
if let Some(bos) = bos_token {
|
||||||
let bos_token_id = tokenizer
|
let bos_token_id = tokenizer
|
||||||
.token_to_id(bos)
|
.token_to_id(bos.as_str())
|
||||||
.expect("Should have found the bos token id");
|
.expect("Should have found the bos token id");
|
||||||
special_tokens.push((bos.clone(), bos_token_id));
|
special_tokens.push((bos.as_str(), bos_token_id));
|
||||||
single.push(format!("{}:0", bos));
|
single.push(format!("{}:0", bos.as_str()));
|
||||||
pair.push(format!("{}:0", bos));
|
pair.push(format!("{}:0", bos.as_str()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -567,17 +567,17 @@ pub fn create_post_processor(
|
|||||||
if add_eos_token {
|
if add_eos_token {
|
||||||
if let Some(eos) = eos_token {
|
if let Some(eos) = eos_token {
|
||||||
let eos_token_id = tokenizer
|
let eos_token_id = tokenizer
|
||||||
.token_to_id(eos)
|
.token_to_id(eos.as_str())
|
||||||
.expect("Should have found the eos token id");
|
.expect("Should have found the eos token id");
|
||||||
special_tokens.push((eos.clone(), eos_token_id));
|
special_tokens.push((eos.as_str(), eos_token_id));
|
||||||
single.push(format!("{}:0", eos));
|
single.push(format!("{}:0", eos.as_str()));
|
||||||
pair.push(format!("{}:0", eos));
|
pair.push(format!("{}:0", eos.as_str()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if add_bos_token {
|
if add_bos_token {
|
||||||
if let Some(bos) = bos_token {
|
if let Some(bos) = bos_token {
|
||||||
single.push(format!("{}:1", bos));
|
single.push(format!("{}:1", bos.as_str()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -585,7 +585,7 @@ pub fn create_post_processor(
|
|||||||
|
|
||||||
if add_eos_token {
|
if add_eos_token {
|
||||||
if let Some(eos) = eos_token {
|
if let Some(eos) = eos_token {
|
||||||
pair.push(format!("{}:1", eos));
|
pair.push(format!("{}:1", eos.as_str()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -611,14 +611,15 @@ enum RouterError {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use text_generation_router::TokenizerConfigToken;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_create_post_processor() {
|
fn test_create_post_processor() {
|
||||||
let tokenizer_config = HubTokenizerConfig {
|
let tokenizer_config = HubTokenizerConfig {
|
||||||
add_bos_token: None,
|
add_bos_token: None,
|
||||||
add_eos_token: None,
|
add_eos_token: None,
|
||||||
bos_token: Some("<s>".to_string()),
|
bos_token: Some(TokenizerConfigToken::String("<s>".to_string())),
|
||||||
eos_token: Some("</s>".to_string()),
|
eos_token: Some(TokenizerConfigToken::String("</s>".to_string())),
|
||||||
chat_template: None,
|
chat_template: None,
|
||||||
tokenizer_class: None,
|
tokenizer_class: None,
|
||||||
completion_template: None,
|
completion_template: None,
|
||||||
|
Loading…
Reference in New Issue
Block a user