diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index cec9ae55..df5be0b1 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -641,8 +641,10 @@ class CausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids[:, 0],
+                        prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                        read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
                     )
                     # Get seed
                     if isinstance(next_token_chooser.choice, Sampling):
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index d6af07f4..6945f44b 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1008,8 +1008,10 @@ class FlashCausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :]
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids,
+                        prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                        read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
                     )
                     generated_text = GeneratedText(
                         output_text,
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index f4177145..02dda681 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -611,11 +611,6 @@ class IdeficsCausalLM(Model):
     def batch_type(self) -> Type[IdeficsCausalLMBatch]:
         return IdeficsCausalLMBatch
 
-    def decode(self, generated_ids: List[int]) -> str:
-        return self.tokenizer.decode(
-            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-
     def forward(
         self,
         input_ids,
@@ -728,8 +723,10 @@ class IdeficsCausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids[:, 0],
+                        prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                        read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
                     )
                     # Get seed
                     if isinstance(next_token_chooser.choice, Sampling):
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 1a7911ac..b8a0daf5 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -710,8 +710,10 @@ class Seq2SeqLM(Model):
                 if stop:
                     # Slice with decoder_input_length to remove padding
                     # Decode all tokens
-                    output_text = self.decode(
-                        all_decoder_input_ids[-decoder_input_length:]
+                    output_text, _, _ = self.decode_token(
+                        all_decoder_input_ids,
+                        prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
+                        read_offset=len(all_decoder_input_ids) - decoder_input_length,
                     )
 
                     # Get seed
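
Across all four hunks the final `output_text` is now produced by `decode_token`, with `prefix_offset` anchored one token before the first generated token and `read_offset` at the first generated token, instead of slicing off the generated ids and passing them to `self.decode`. Below is a minimal standalone sketch of that offset-based decode, assuming a `transformers`-style tokenizer; the helper name `decode_with_offsets` and the `gpt2` checkpoint are illustrative, not part of this change.

```python
from typing import List, Tuple

from transformers import AutoTokenizer


def decode_with_offsets(
    tokenizer, all_input_ids: List[int], prefix_offset: int, read_offset: int
) -> Tuple[str, int, int]:
    """Decode only the tokens after `read_offset`, using the tokens from
    `prefix_offset` onward as context so the tokenizer renders leading
    spaces and merged bytes the same way it would while streaming."""
    prefix_text = tokenizer.decode(
        all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
    )
    new_text = tokenizer.decode(all_input_ids[prefix_offset:], skip_special_tokens=False)
    if len(new_text) > len(prefix_text) and not new_text.endswith("\ufffd"):
        # Strip the decoded prefix; what remains is the newly generated text.
        return new_text[len(prefix_text) :], read_offset, len(all_input_ids)
    # A trailing U+FFFD means an unfinished byte sequence: emit nothing yet.
    return "", prefix_offset, read_offset


if __name__ == "__main__":
    # Usage mirroring the diff: `current_tokens` generated ids sit at the end
    # of `all_input_ids`, and the offsets are anchored one token earlier.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    all_input_ids = tokenizer("Hello, the answer is forty two").input_ids
    current_tokens = 2  # pretend the last two ids were generated
    text, _, _ = decode_with_offsets(
        tokenizer,
        all_input_ids,
        prefix_offset=len(all_input_ids) - current_tokens - 1,
        read_offset=len(all_input_ids) - current_tokens,
    )
    print(repr(text))  # keeps the leading space: ' forty two'
```

The one-token prefix exists only so the tokenizer sees the id that precedes the generated text; with BPE- or SentencePiece-style tokenizers that attach the space to the following token, decoding the generated ids in isolation can drop that leading space from `output_text`.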