diff --git a/server/tests/models/test_model.py b/server/tests/models/test_model.py
index fbcf873d..32bcd45f 100644
--- a/server/tests/models/test_model.py
+++ b/server/tests/models/test_model.py
@@ -6,8 +6,7 @@ from transformers import AutoTokenizer
 from text_generation_server.models import Model
 
 
-@pytest.mark.private
-def test_decode_streaming():
+def get_test_model():
     class TestModel(Model):
         def batch_type(self):
             raise NotImplementedError
@@ -20,7 +19,34 @@ def test_decode_streaming():
     model = TestModel(
         torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
     )
+    return model
+
+@pytest.mark.private
+def test_decode_streaming_english_spaces():
+    model = get_test_model()
+    truth = "Hello here, this is a simple test"
+    all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243]
+    assert (
+        all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"]
+    )
+
+    decoded_text = ""
+    offset = 0
+    token_offset = 0
+    for i in range(len(all_input_ids)):
+        text, offset, token_offset = model.decode_token(
+            all_input_ids[: i + 1], offset, token_offset
+        )
+        decoded_text += text
+
+    assert decoded_text == truth
+
+
+@pytest.mark.private
+def test_decode_streaming_chinese_utf8():
+    model = get_test_model()
+    truth = "我很感谢你的热情"
 
     all_input_ids = [
         30672,
         232,
@@ -40,11 +66,9 @@ def test_decode_streaming():
         30993,
     ]
 
-    truth = "我很感谢你的热情"
-
     decoded_text = ""
-    offset = None
-    token_offset = None
+    offset = 0
+    token_offset = 0
     for i in range(len(all_input_ids)):
         text, offset, token_offset = model.decode_token(
             all_input_ids[: i + 1], offset, token_offset
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index f19fecb8..657e4821 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -56,52 +56,25 @@ class Model(ABC):
     def decode_token(
         self,
         all_input_ids: List[int],
-        offset: Optional[int] = None,
-        token_offset: Optional[int] = None,
-    ) -> Tuple[str, Optional[int], Optional[int]]:
+        prefix_offset: int = 0,
+        read_offset: int = 0,
+    ) -> Tuple[str, int, int]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
 
-        if all_input_ids[-1] in self.all_special_ids:
-            return (
-                self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False),
-                None,
-                None,
-            )
-        if token_offset is None:
-            token_offset = max(len(all_input_ids) - self.decode_buffer, 0)
-            # left token buffer
-            if self.decode_buffer > 1:
-                # Decode token_offset token minus last one and token_offset tokens
-                raw_texts = self.tokenizer.batch_decode(
-                    [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
-                    skip_special_tokens=False,
-                )
+        # Compatibility layer for old None values.
+        if prefix_offset is None:
+            prefix_offset = 0
+        if read_offset is None:
+            read_offset = 0
 
-                # default offset is only the last token
-                offset = len(raw_texts[0])
-                sequence_text = raw_texts[1]
-            else:
-                # Only decode the last token without using a token buffer
-                sequence_text = self.tokenizer.decode(
-                    all_input_ids[-1], skip_special_tokens=False
-                )
-                # no offset in this case
-                offset = 0
+        prefix_text = self.tokenizer.decode(all_input_ids[prefix_offset:read_offset])
+        new_text = self.tokenizer.decode(all_input_ids[prefix_offset:])
+
+        if len(new_text) > len(prefix_text) and "�" not in new_text:
+            new_text = new_text[len(prefix_text) :]
+            return new_text, read_offset, len(all_input_ids)
         else:
-            assert offset is not None
-            sequence_text = self.tokenizer.decode(
-                all_input_ids[token_offset:],
-                skip_special_tokens=False,
-            )
-
-        # get text
-        token_text = sequence_text[offset:]
-
-        # if text is utf-8
-        if token_text and token_text[-1] != "�":
-            return token_text, None, None
-        else:
-            return "", offset, token_offset
+            return "", prefix_offset, read_offset
 
     def check_initialized(self):
         uninitialized_parameters = []
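
Note: the decode_token rewrite above works by decoding two overlapping windows of ids (the previously read window and that window extended by the new ids) and emitting only the text that extends past the prefix, holding output back while the trailing ids still form an incomplete UTF-8 sequence. Below is a minimal standalone sketch of that idea, not part of the patch; the free-function form and the "huggyllama/llama-7b" checkpoint are illustrative assumptions.

    from typing import List, Tuple

    from transformers import AutoTokenizer


    def decode_token(
        tokenizer,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
    ) -> Tuple[str, int, int]:
        # Decode the already-read window and the window extended by the new ids;
        # the freshly generated text is whatever extends past the prefix text.
        prefix_text = tokenizer.decode(all_input_ids[prefix_offset:read_offset])
        new_text = tokenizer.decode(all_input_ids[prefix_offset:])

        if len(new_text) > len(prefix_text) and "�" not in new_text:
            # At least one complete character was produced: emit the suffix
            # and advance both offsets.
            return new_text[len(prefix_text):], read_offset, len(all_input_ids)
        # A replacement character signals an unfinished byte sequence: emit
        # nothing and keep the offsets so the next call retries with more ids.
        return "", prefix_offset, read_offset


    if __name__ == "__main__":
        # Assumption: any LLaMA-style SentencePiece tokenizer works here.
        tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
        ids = tokenizer("我很感谢你的热情", add_special_tokens=False)["input_ids"]

        text, prefix_offset, read_offset = "", 0, 0
        for i in range(len(ids)):
            piece, prefix_offset, read_offset = decode_token(
                tokenizer, ids[: i + 1], prefix_offset, read_offset
            )
            text += piece
        assert text == "我很感谢你的热情"

This mirrors the loop used in the new tests: offsets start at 0, each call sees the full id prefix generated so far, and the concatenated pieces reproduce the reference string without dropped spaces or broken multi-byte characters.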