From 56d23753bb18ca7ed3d4ea5b4d18452a68e13dbd Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 6 Mar 2023 13:10:12 +0100 Subject: [PATCH] add clean_up_tokenization_spaces --- server/text_generation/models/model.py | 6 ++++-- server/text_generation/models/seq2seq_lm.py | 4 +++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/server/text_generation/models/model.py b/server/text_generation/models/model.py index eb39c099..09fa6a2a 100644 --- a/server/text_generation/models/model.py +++ b/server/text_generation/models/model.py @@ -36,6 +36,8 @@ class Model(ABC): def decode_token(self, token_id: int) -> str: """Hack to hopefully support generate_stream for the maximum number of tokenizers""" # append token to special decode token and decode both - result = self.tokenizer.decode([self.special_decode_token_id, token_id], skip_special_tokens=False) + result = self.tokenizer.decode( + [self.special_decode_token_id, token_id], skip_special_tokens=False + ) # slice to remove special decode token - return result[self.special_decode_token_length:] + return result[self.special_decode_token_length :] diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index ce1f7993..4b88baec 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -342,7 +342,9 @@ class Seq2SeqLM(Model): return Seq2SeqLMBatch def decode(self, decoder_ids: List[int]) -> str: - return self.tokenizer.decode(decoder_ids, skip_special_tokens=True) + return self.tokenizer.decode( + decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) def forward( self,