diff --git a/server/text_generation/models/model.py b/server/text_generation/models/model.py
index eb39c099..09fa6a2a 100644
--- a/server/text_generation/models/model.py
+++ b/server/text_generation/models/model.py
@@ -36,6 +36,8 @@ class Model(ABC):
     def decode_token(self, token_id: int) -> str:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
         # append token to special decode token and decode both
-        result = self.tokenizer.decode([self.special_decode_token_id, token_id], skip_special_tokens=False)
+        result = self.tokenizer.decode(
+            [self.special_decode_token_id, token_id], skip_special_tokens=False
+        )
         # slice to remove special decode token
-        return result[self.special_decode_token_length:]
+        return result[self.special_decode_token_length :]
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index ce1f7993..4b88baec 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -342,7 +342,9 @@ class Seq2SeqLM(Model):
         return Seq2SeqLMBatch
 
     def decode(self, decoder_ids: List[int]) -> str:
-        return self.tokenizer.decode(
-            decoder_ids, skip_special_tokens=True
-        )
+        return self.tokenizer.decode(
+            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
 
     def forward(
         self,
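
A minimal sketch (not part of the patch above) of what the clean_up_tokenization_spaces=False change in Seq2SeqLM.decode guards against: by default, transformers applies a cleanup pass after decoding that collapses spaces before punctuation, so decode() is not faithful to the underlying tokens. The "t5-small" checkpoint below is an arbitrary illustration, not something this repository pins.

    # Illustrative sketch only; assumes the Hugging Face `transformers`
    # library and an arbitrary example checkpoint ("t5-small").
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    ids = tokenizer("Hello , world !").input_ids

    # Default cleanup rewrites " ," as "," and " !" as "!", so the
    # round trip does not preserve the original spacing.
    print(tokenizer.decode(ids, skip_special_tokens=True))
    # -> "Hello, world!"

    # clean_up_tokenization_spaces=False keeps decode() faithful to the
    # tokens, which is what the patched Seq2SeqLM.decode relies on.
    print(
        tokenizer.decode(
            ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
    )
    # -> "Hello , world !"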