mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix: adjust when post_processor is overridden and improve create_post_processor
This commit is contained in:
parent
74535ce80f
commit
a921854d92
@ -309,9 +309,9 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
let mut tokenizer = Tokenizer::from_file(filename).ok();
|
let mut tokenizer = Tokenizer::from_file(filename).ok();
|
||||||
if let Some(tokenizer) = &mut tokenizer {
|
if let Some(tokenizer) = &mut tokenizer {
|
||||||
if let Some(class) = &tokenizer_config.tokenizer_class {
|
if let Some(class) = &tokenizer_config.tokenizer_class {
|
||||||
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" {
|
if (class == "LlamaTokenizer" || class == "LlamaTokenizerFast") && tokenizer.get_post_processor().is_none() {
|
||||||
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
|
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
|
||||||
if let Some(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
|
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
|
||||||
tokenizer.with_post_processor(post_processor);
|
tokenizer.with_post_processor(post_processor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -531,7 +531,7 @@ pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConf
|
|||||||
pub fn create_post_processor(
|
pub fn create_post_processor(
|
||||||
tokenizer: &Tokenizer,
|
tokenizer: &Tokenizer,
|
||||||
tokenizer_config: &HubTokenizerConfig,
|
tokenizer_config: &HubTokenizerConfig,
|
||||||
) -> Option<TemplateProcessing> {
|
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
|
||||||
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
|
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
|
||||||
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
|
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
|
||||||
|
|
||||||
@ -546,53 +546,56 @@ pub fn create_post_processor(
|
|||||||
panic!("add_eos_token = true but eos_token is None");
|
panic!("add_eos_token = true but eos_token is None");
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut single = String::new();
|
let mut single = Vec::new();
|
||||||
let mut pair = String::new();
|
let mut pair = Vec::new();
|
||||||
let mut special_tokens = Vec::new();
|
let mut special_tokens = Vec::new();
|
||||||
|
|
||||||
if add_bos_token {
|
if add_bos_token {
|
||||||
let bos = bos_token.unwrap();
|
if let Some(bos) = bos_token {
|
||||||
let bos_token_id = tokenizer
|
let bos_token_id = tokenizer
|
||||||
.token_to_id(bos)
|
.token_to_id(bos)
|
||||||
.expect("Should have found the bos token id");
|
.expect("Should have found the bos token id");
|
||||||
special_tokens.push((bos.clone(), bos_token_id));
|
special_tokens.push((bos.clone(), bos_token_id));
|
||||||
single.push_str(&format!("{}:0 ", bos));
|
single.push(format!("{}:0", bos));
|
||||||
pair.push_str(&format!("{}:0 ", bos));
|
pair.push(format!("{}:0", bos));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
single.push_str("$A:0");
|
single.push("$A:0".to_string());
|
||||||
pair.push_str("$A:0");
|
pair.push("$A:0".to_string());
|
||||||
|
|
||||||
if add_eos_token {
|
if add_eos_token {
|
||||||
let eos = eos_token.unwrap();
|
if let Some(eos) = eos_token {
|
||||||
let eos_token_id = tokenizer
|
let eos_token_id = tokenizer
|
||||||
.token_to_id(eos)
|
.token_to_id(eos)
|
||||||
.expect("Should have found the eos token id");
|
.expect("Should have found the eos token id");
|
||||||
special_tokens.push((eos.clone(), eos_token_id));
|
special_tokens.push((eos.clone(), eos_token_id));
|
||||||
single.push_str(&format!(" {}:0", eos));
|
single.push(format!("{}:0", eos));
|
||||||
pair.push_str(&format!(" {}:0", eos));
|
pair.push(format!("{}:0", eos));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if add_bos_token {
|
if add_bos_token {
|
||||||
pair.push_str(&format!(" {}:1", bos_token.unwrap()));
|
if let Some(bos) = bos_token {
|
||||||
|
single.push(format!("{}:1", bos));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pair.push_str(" $B:1");
|
pair.push("$B:1".to_string());
|
||||||
|
|
||||||
if add_eos_token {
|
if add_eos_token {
|
||||||
pair.push_str(&format!(" {}:1", eos_token.unwrap()));
|
if let Some(eos) = eos_token {
|
||||||
|
pair.push(format!("{}:1", eos));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let post_processor = TemplateProcessing::builder()
|
let post_processor = TemplateProcessing::builder()
|
||||||
.try_single(single)
|
.try_single(single)?
|
||||||
.unwrap()
|
.try_pair(pair)?
|
||||||
.try_pair(pair)
|
|
||||||
.unwrap()
|
|
||||||
.special_tokens(special_tokens)
|
.special_tokens(special_tokens)
|
||||||
.build()
|
.build()?;
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
Some(post_processor)
|
Ok(post_processor)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
@ -626,9 +629,9 @@ mod tests {
|
|||||||
let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
|
let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
|
||||||
|
|
||||||
let expected = TemplateProcessing::builder()
|
let expected = TemplateProcessing::builder()
|
||||||
.try_single("<s>:0 $A:0")
|
.try_single("<s>:0 $A:0 <s>:1")
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.try_pair("<s>:0 $A:0 <s>:1 $B:1")
|
.try_pair("<s>:0 $A:0 $B:1")
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.special_tokens(vec![("<s>".to_string(), 1)])
|
.special_tokens(vec![("<s>".to_string(), 1)])
|
||||||
.build()
|
.build()
|
||||||
|
Loading…
Reference in New Issue
Block a user