add clean_up_tokenization_spaces

OlivierDehaene 2023-03-06 13:10:12 +01:00
parent 6a56f945c0
commit 56d23753bb
2 changed files with 7 additions and 3 deletions

@@ -36,6 +36,8 @@ class Model(ABC):
     def decode_token(self, token_id: int) -> str:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
         # append token to special decode token and decode both
-        result = self.tokenizer.decode([self.special_decode_token_id, token_id], skip_special_tokens=False)
+        result = self.tokenizer.decode(
+            [self.special_decode_token_id, token_id], skip_special_tokens=False
+        )
         # slice to remove special decode token
-        return result[self.special_decode_token_length:]
+        return result[self.special_decode_token_length :]

@@ -342,7 +342,9 @@ class Seq2SeqLM(Model):
         return Seq2SeqLMBatch

     def decode(self, decoder_ids: List[int]) -> str:
-        return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
+        return self.tokenizer.decode(
+            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )

     def forward(
         self,
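
For context, a minimal sketch (not part of this commit) of what clean_up_tokenization_spaces controls in transformers' decode: with the default cleanup, spaces before punctuation and contractions are collapsed; passing False keeps the raw detokenized spacing, which presumably keeps Seq2SeqLM.decode consistent with the raw token-level decoding used for streaming. The t5-small checkpoint below is an illustrative assumption.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # illustrative checkpoint

ids = tokenizer("Hello , world !", add_special_tokens=False).input_ids

# default cleanup collapses the space before punctuation, typically "Hello, world!"
print(tokenizer.decode(ids, skip_special_tokens=True))

# cleanup disabled keeps the detokenized spacing, e.g. "Hello , world !"
print(
    tokenizer.decode(
        ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
)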