Simplifying streaming decode.

Nicolas Patry 2023-05-16 11:32:25 +02:00 committed by OlivierDehaene
parent d2a99b4294
commit 1aa31bb5cc
2 changed files with 45 additions and 48 deletions
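
In short: decode_token previously threaded Optional offsets plus a model-level decode_buffer through streaming, with a separate path for special tokens; it now tracks just two integer offsets into the id sequence and diffs the decoded text. The signature change at a glance (condensed from the diff below):

    # before
    def decode_token(self, all_input_ids, offset=None, token_offset=None) -> Tuple[str, Optional[int], Optional[int]]
    # after
    def decode_token(self, all_input_ids, prefix_offset=0, read_offset=0) -> Tuple[str, int, int]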

View File

@@ -6,8 +6,7 @@ from transformers import AutoTokenizer
 from text_generation_server.models import Model


-@pytest.mark.private
-def test_decode_streaming():
+def get_test_model():
     class TestModel(Model):
         def batch_type(self):
             raise NotImplementedError
@@ -20,7 +19,34 @@ def test_decode_streaming():
     model = TestModel(
         torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
     )
+    return model
+
+
+@pytest.mark.private
+def test_decode_streaming_english_spaces():
+    model = get_test_model()
+    truth = "Hello here, this is a simple test"
+    all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243]
+    assert (
+        all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"]
+    )
+
+    decoded_text = ""
+    offset = 0
+    token_offset = 0
+    for i in range(len(all_input_ids)):
+        text, offset, token_offset = model.decode_token(
+            all_input_ids[: i + 1], offset, token_offset
+        )
+        decoded_text += text
+
+    assert decoded_text == truth
+
+
+@pytest.mark.private
+def test_decode_streaming_chinese_utf8():
+    model = get_test_model()
+    truth = "我很感谢你的热情"
     all_input_ids = [
         30672,
         232,
@@ -40,11 +66,9 @@
         30993,
     ]
-    truth = "我很感谢你的热情"
     decoded_text = ""
-    offset = None
-    token_offset = None
+    offset = 0
+    token_offset = 0
     for i in range(len(all_input_ids)):
         text, offset, token_offset = model.decode_token(
             all_input_ids[: i + 1], offset, token_offset
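
The Chinese test exists because byte-fallback tokenizers such as Llama's split one character across several ids, so a prefix of the stream can end mid-character. A quick standalone illustration of the effect decode_token guards against (plain Python, no tokenizer required; the variable names are ours):

    # "我" is three UTF-8 bytes. A decoder that has only seen a prefix of
    # them cannot produce the character yet; errors="replace" surfaces the
    # unfinished sequence as U+FFFD, the marker decode_token checks for.
    char_bytes = "我".encode("utf-8")  # b'\xe6\x88\x91'
    partial = char_bytes[:2].decode("utf-8", errors="replace")
    complete = char_bytes.decode("utf-8", errors="replace")
    print(partial)   # '�'  -> unfinished, hold the text back
    print(complete)  # '我' -> safe to emit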

View File

@@ -56,52 +56,25 @@ class Model(ABC):
     def decode_token(
         self,
         all_input_ids: List[int],
-        offset: Optional[int] = None,
-        token_offset: Optional[int] = None,
-    ) -> Tuple[str, Optional[int], Optional[int]]:
+        prefix_offset: int = 0,
+        read_offset: int = 0,
+    ) -> Tuple[str, int, int]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
-        if all_input_ids[-1] in self.all_special_ids:
-            return (
-                self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False),
-                None,
-                None,
-            )
-
-        if token_offset is None:
-            token_offset = max(len(all_input_ids) - self.decode_buffer, 0)
-
-            # left token buffer
-            if self.decode_buffer > 1:
-                # Decode token_offset token minus last one and token_offset tokens
-                raw_texts = self.tokenizer.batch_decode(
-                    [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
-                    skip_special_tokens=False,
-                )
-
-                # default offset is only the last token
-                offset = len(raw_texts[0])
-                sequence_text = raw_texts[1]
-            else:
-                # Only decode the last token without using a token buffer
-                sequence_text = self.tokenizer.decode(
-                    all_input_ids[-1], skip_special_tokens=False
-                )
-
-                # no offset in this case
-                offset = 0
-        else:
-            assert offset is not None
-            sequence_text = self.tokenizer.decode(
-                all_input_ids[token_offset:],
-                skip_special_tokens=False,
-            )
-
-        # get text
-        token_text = sequence_text[offset:]
-
-        # if text is utf-8
-        if token_text and token_text[-1] != "�":
-            return token_text, None, None
-        else:
-            return "", offset, token_offset
+        # Compatibility layer for old None values.
+        if prefix_offset is None:
+            prefix_offset = 0
+        if read_offset is None:
+            read_offset = 0
+
+        prefix_text = self.tokenizer.decode(all_input_ids[prefix_offset:read_offset])
+        new_text = self.tokenizer.decode(all_input_ids[prefix_offset:])
+
+        if len(new_text) > len(prefix_text) and "�" not in new_text:
+            new_text = new_text[len(prefix_text) :]
+            return new_text, read_offset, len(all_input_ids)
+        else:
+            return "", prefix_offset, read_offset

     def check_initialized(self):
         uninitialized_parameters = []
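
The same logic reads more easily as a free function. A minimal sketch, assuming any transformers tokenizer; the function and driver below are illustrative, not part of this commit:

    from typing import List, Tuple

    from transformers import PreTrainedTokenizerBase


    def decode_token(
        tokenizer: PreTrainedTokenizerBase,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
    ) -> Tuple[str, int, int]:
        # Decode a known-good prefix together with the new ids. Keeping the
        # prefix in the call defeats cleanup heuristics (leading spaces,
        # merged bytes) that depend on the surrounding ids.
        prefix_text = tokenizer.decode(all_input_ids[prefix_offset:read_offset])
        new_text = tokenizer.decode(all_input_ids[prefix_offset:])
        if len(new_text) > len(prefix_text) and "�" not in new_text:
            # Finished, valid text: emit only the delta and slide the window.
            return new_text[len(prefix_text) :], read_offset, len(all_input_ids)
        # Unfinished byte sequence: emit nothing, keep the same offsets, and
        # retry once more ids arrive.
        return "", prefix_offset, read_offset


    def stream(tokenizer: PreTrainedTokenizerBase, ids: List[int]):
        # Drive the function the way the tests above do, one id at a time.
        prefix_offset = read_offset = 0
        for i in range(len(ids)):
            text, prefix_offset, read_offset = decode_token(
                tokenizer, ids[: i + 1], prefix_offset, read_offset
            )
            if text:
                yield text

Because the offsets are plain ints, callers no longer need a None-initialized first step (the compatibility layer above keeps old callers working), and special tokens flow through the same path as ordinary text.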