diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 4e338263..a543e073 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -641,8 +641,8 @@ class CausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    output_text = self.decode_generated_tokens(
+                        all_input_ids[:, 0], stopping_criteria.current_tokens
                     )
                     # Get seed
                     if isinstance(next_token_chooser.choice, Sampling):
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index d6af07f4..0cfbe446 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1008,9 +1008,11 @@ class FlashCausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :]
+                    output_text = self.decode_generated_tokens(
+                        all_input_ids,
+                        stopping_criteria.current_tokens,
                     )
+
                     generated_text = GeneratedText(
                         output_text,
                         stopping_criteria.current_tokens,
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index 2dac87bc..a8481483 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -728,8 +728,8 @@ class IdeficsCausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    output_text = self.decode_generated_tokens(
+                        all_input_ids[:, 0], stopping_criteria.current_tokens
                     )
                     # Get seed
                     if isinstance(next_token_chooser.choice, Sampling):
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 806e9833..ae574009 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -85,6 +85,21 @@ class Model(ABC):
             return new_text, read_offset, len(all_input_ids)
         else:
             return "", prefix_offset, read_offset
+
+    def decode_generated_tokens(
+        self,
+        all_input_ids: List[int],
+        num_tokens: int = 0,
+    ) -> str:
+        # Like in `decode_token()`, the prefix text is necessary only to defeat cleanup algorithms in the decode.
+        prefix_text = self.tokenizer.decode(
+            all_input_ids[-num_tokens-1:-num_tokens], skip_special_tokens=False
+        )
+        new_text = self.tokenizer.decode(
+            all_input_ids[-num_tokens-1:], skip_special_tokens=False
+        )
+        new_text = new_text[len(prefix_text):]
+        return new_text
 
     def check_initialized(self):
         uninitialized_parameters = []
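
Note on the technique: `decode_generated_tokens()` decodes one token of context before the generated span and then cuts that prefix text off the result, because decoding a suffix of token ids in isolation lets the tokenizer's cleanup logic mangle the boundary (for example, SentencePiece strips the leading space from the first decoded token). Below is a minimal standalone sketch of the same trick, not part of this PR; it assumes a SentencePiece-based checkpoint, and the model id is only a placeholder.

    from transformers import AutoTokenizer

    # Placeholder checkpoint; any Llama-style SentencePiece tokenizer shows the effect.
    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

    all_input_ids = tokenizer("Hello brave new world").input_ids
    num_tokens = 2  # pretend the last two tokens ("new", "world") were generated

    # Naive suffix decode: cleanup drops the space before "new".
    naive = tokenizer.decode(all_input_ids[-num_tokens:], skip_special_tokens=False)

    # Trick from this PR: decode one extra prefix token so the boundary is
    # mid-sequence, then strip the prefix text off the front.
    prefix_text = tokenizer.decode(
        all_input_ids[-num_tokens - 1 : -num_tokens], skip_special_tokens=False
    )
    new_text = tokenizer.decode(
        all_input_ids[-num_tokens - 1 :], skip_special_tokens=False
    )
    print(repr(naive))                         # typically 'new world'
    print(repr(new_text[len(prefix_text):]))   # typically ' new world'

The sketch mirrors the body of the new `Model.decode_generated_tokens()` above; the surrounding call sites only change from slicing off the generated tokens themselves to passing the full id sequence plus `stopping_criteria.current_tokens`.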