diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 685177c7..dce793f5 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -627,10 +627,11 @@ class CausalLM(Model):
     def batch_type(self) -> Type[CausalLMBatch]:
         return CausalLMBatch
 
-    def decode(self, generated_ids: List[int]) -> str:
-        return self.tokenizer.decode(
-            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
+    # This is not used anymore
+    # def decode(self, generated_ids: List[int]) -> str:
+    #     return self.tokenizer.decode(
+    #         generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    #     )
 
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 5f558caa..77125f53 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -827,6 +827,7 @@ class FlashCausalLM(Model):
         aliases=None,
         # Used for Santacoder override of config
         num_kv_heads=None,
+        skip_special_tokens: bool = True,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 38695b19..5d16c364 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -668,10 +668,11 @@ class Seq2SeqLM(Model):
     def batch_type(self) -> Type[Seq2SeqLMBatch]:
         return Seq2SeqLMBatch
 
-    def decode(self, decoder_ids: List[int]) -> str:
-        return self.tokenizer.decode(
-            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
+    # Not used anymore
+    # def decode(self, decoder_ids: List[int]) -> str:
+    #     return self.tokenizer.decode(
+    #         decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    #     )
 
     def forward(
         self,
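
For context, a minimal sketch of the decoding behavior the new `skip_special_tokens` constructor argument presumably toggles. This is not part of the diff above: it uses the standard Hugging Face `transformers` API, and the `t5-small` checkpoint is chosen purely for illustration.

```python
# Illustration only: what skip_special_tokens changes in tokenizer.decode().
# Assumes `pip install transformers`; the model choice is arbitrary.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")

# T5's tokenizer appends the </s> end-of-sequence special token.
ids = tokenizer("hello world").input_ids

print(tokenizer.decode(ids, skip_special_tokens=False))
# hello world</s>
print(tokenizer.decode(ids, skip_special_tokens=True))
# hello world
```

With the flag threaded through the model constructor rather than hard-coded at each `decode` call site (as in the removed per-model `decode` helpers), callers can decide whether special tokens should survive detokenization.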