From 8a16c485950e5d2b51b1b597a2fbd1d50d1e05fe Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 26 Sep 2023 16:35:53 +0200
Subject: [PATCH] Ignoring special tokens + updating 1 test case.

---
 server/tests/models/test_seq2seq_lm.py                    | 4 ++--
 server/text_generation_server/models/causal_lm.py         | 1 +
 server/text_generation_server/models/flash_causal_lm.py   | 6 +-----
 server/text_generation_server/models/idefics_causal_lm.py | 1 +
 server/text_generation_server/models/model.py             | 5 +++--
 server/text_generation_server/models/seq2seq_lm.py        | 1 +
 6 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py
index 299340f8..b2e83e97 100644
--- a/server/tests/models/test_seq2seq_lm.py
+++ b/server/tests/models/test_seq2seq_lm.py
@@ -168,7 +168,7 @@ def test_seq2seq_lm_generate_token_completion(
     assert next_batch is None
 
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "a few weeks"
+    assert generations[0].generated_text.text == " a few weeks"
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7
 
@@ -186,7 +186,7 @@ def test_seq2seq_lm_generate_token_completion_multi(
     assert next_batch is not None
 
     assert len(generations) == 2
-    assert generations[1].generated_text.text == "a few "
+    assert generations[1].generated_text.text == " a few "
     assert (
         generations[1].request_id
         == default_multi_requests_seq2seq_lm_batch.requests[1].id
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index df5be0b1..0c5a7f0c 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -645,6 +645,7 @@ class CausalLM(Model):
                     all_input_ids[:, 0],
                     prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
                     read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                    skip_special_tokens=True
                 )
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 6945f44b..12d8efeb 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -793,11 +793,6 @@ class FlashCausalLM(Model):
 
         return int(num_blocks * BLOCK_SIZE)
 
-    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
-        return self.tokenizer.decode(
-            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1012,6 +1007,7 @@ class FlashCausalLM(Model):
                     all_input_ids,
                     prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
                     read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                    skip_special_tokens=True
                 )
                 generated_text = GeneratedText(
                     output_text,
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index 02dda681..30cc2299 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -727,6 +727,7 @@ class IdeficsCausalLM(Model):
                     all_input_ids[:, 0],
                     prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
                     read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                    skip_special_tokens=True
                 )
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 806e9833..73329b24 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -64,16 +64,17 @@ class Model(ABC):
         all_input_ids: List[int],
         prefix_offset: int = 0,
         read_offset: int = 0,
+        skip_special_tokens: bool = False,
     ) -> Tuple[str, int, int]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
 
         # The prefix text is necessary only to defeat cleanup algorithms in the decode
         # which decide to add a space or not depending on the surrounding ids.
         prefix_text = self.tokenizer.decode(
-            all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
+            all_input_ids[prefix_offset:read_offset], skip_special_tokens=skip_special_tokens
        )
         new_text = self.tokenizer.decode(
-            all_input_ids[prefix_offset:], skip_special_tokens=False
+            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
         )
 
         if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index b8a0daf5..679b6fb4 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -714,6 +714,7 @@ class Seq2SeqLM(Model):
                     all_decoder_input_ids,
                     prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
                     read_offset=len(all_decoder_input_ids) - decoder_input_length,
+                    skip_special_tokens=True
                 )
                 # Get seed
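
For reference, here is a minimal, self-contained sketch of the incremental-decode helper this patch extends with the `skip_special_tokens` flag. It is written as a standalone function rather than the `decode_token` method on `Model`, the `tokenizer` argument stands for any Hugging Face-style tokenizer with a `decode(ids, skip_special_tokens=...)` method, and the body of the success branch follows the logic surrounding the model.py hunk above; treat it as an illustration, not the repository's exact code.

    # Illustrative sketch of decode_token with the new skip_special_tokens flag.
    from typing import List, Tuple


    def decode_token(
        tokenizer,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
        skip_special_tokens: bool = False,
    ) -> Tuple[str, int, int]:
        # Decode the previously seen window and the full window with the same
        # flags, so any space cleanup the tokenizer applies cancels out in the diff.
        prefix_text = tokenizer.decode(
            all_input_ids[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
        )
        new_text = tokenizer.decode(
            all_input_ids[prefix_offset:],
            skip_special_tokens=skip_special_tokens,
        )

        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
            # The decode grew past the old prefix and does not end in an
            # incomplete byte sequence: emit only the new suffix and advance.
            return new_text[len(prefix_text):], read_offset, len(all_input_ids)

        # Otherwise emit nothing yet and keep the offsets unchanged.
        return "", prefix_offset, read_offset

Passing the flag to both decode calls keeps the prefix and full decodes comparable, so skipping special tokens does not desynchronize the offsets. With the callers now passing skip_special_tokens=True for the final text, the dedicated FlashCausalLM.decode helper becomes redundant, and the updated seq2seq test expectations keep the leading space produced by this decode path.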