Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
Simplifying streaming decode.

parent d2a99b4294
commit 1aa31bb5cc
server/tests/models/test_model.py
@@ -6,8 +6,7 @@ from transformers import AutoTokenizer
 from text_generation_server.models import Model
 
 
-@pytest.mark.private
-def test_decode_streaming():
+def get_test_model():
     class TestModel(Model):
         def batch_type(self):
             raise NotImplementedError
@@ -20,7 +19,34 @@ def test_decode_streaming():
     model = TestModel(
         torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
     )
-
+    return model
+
+
+@pytest.mark.private
+def test_decode_streaming_english_spaces():
+    model = get_test_model()
+    truth = "Hello here, this is a simple test"
+    all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243]
+    assert (
+        all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"]
+    )
+
+    decoded_text = ""
+    offset = 0
+    token_offset = 0
+    for i in range(len(all_input_ids)):
+        text, offset, token_offset = model.decode_token(
+            all_input_ids[: i + 1], offset, token_offset
+        )
+        decoded_text += text
+
+    assert decoded_text == truth
+
+
+@pytest.mark.private
+def test_decode_streaming_chinese_utf8():
+    model = get_test_model()
+    truth = "我很感谢你的热情"
     all_input_ids = [
         30672,
         232,
@@ -40,11 +66,9 @@ def test_decode_streaming():
         30993,
     ]
 
-    truth = "我很感谢你的热情"
-
     decoded_text = ""
-    offset = None
-    token_offset = None
+    offset = 0
+    token_offset = 0
     for i in range(len(all_input_ids)):
         text, offset, token_offset = model.decode_token(
             all_input_ids[: i + 1], offset, token_offset
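The streaming loop above now seeds both offsets with plain ints (0) instead of None; the model.py hunk below keeps a compatibility layer that coerces None to 0, so older call sites continue to work. A minimal sketch of the two calling conventions, reusing the get_test_model() helper introduced in this file (illustrative only, not part of the commit):

    model = get_test_model()
    ids = model.tokenizer("Hello here, this is a simple test", add_special_tokens=False)["input_ids"]

    # New convention: integer offsets threaded through the loop.
    text, prefix_offset, read_offset = model.decode_token(ids[:1], 0, 0)
    # Old convention: None seeds are still accepted, coerced to 0 internally.
    text, prefix_offset, read_offset = model.decode_token(ids[:1], None, None)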
server/text_generation_server/models/model.py
@@ -56,52 +56,25 @@ class Model(ABC):
     def decode_token(
         self,
         all_input_ids: List[int],
-        offset: Optional[int] = None,
-        token_offset: Optional[int] = None,
-    ) -> Tuple[str, Optional[int], Optional[int]]:
+        prefix_offset: int = 0,
+        read_offset: int = 0,
+    ) -> Tuple[str, int, int]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
-        if all_input_ids[-1] in self.all_special_ids:
-            return (
-                self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False),
-                None,
-                None,
-            )
-
-        if token_offset is None:
-            token_offset = max(len(all_input_ids) - self.decode_buffer, 0)
-            # left token buffer
-            if self.decode_buffer > 1:
-                # Decode token_offset token minus last one and token_offset tokens
-                raw_texts = self.tokenizer.batch_decode(
-                    [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
-                    skip_special_tokens=False,
-                )
+        # Compatibility layer for old None values.
+        if prefix_offset is None:
+            prefix_offset = 0
+        if read_offset is None:
+            read_offset = 0
 
-                # default offset is only the last token
-                offset = len(raw_texts[0])
-                sequence_text = raw_texts[1]
-            else:
-                # Only decode the last token without using a token buffer
-                sequence_text = self.tokenizer.decode(
-                    all_input_ids[-1], skip_special_tokens=False
-                )
-                # no offset in this case
-                offset = 0
+        prefix_text = self.tokenizer.decode(all_input_ids[prefix_offset:read_offset])
+        new_text = self.tokenizer.decode(all_input_ids[prefix_offset:])
+
+        if len(new_text) > len(prefix_text) and "�" not in new_text:
+            new_text = new_text[len(prefix_text) :]
+            return new_text, read_offset, len(all_input_ids)
         else:
-            assert offset is not None
-            sequence_text = self.tokenizer.decode(
-                all_input_ids[token_offset:],
-                skip_special_tokens=False,
-            )
-
-        # get text
-        token_text = sequence_text[offset:]
-
-        # if text is utf-8
-        if token_text and token_text[-1] != "�":
-            return token_text, None, None
-        else:
-            return "", offset, token_offset
+            return "", prefix_offset, read_offset
 
     def check_initialized(self):
         uninitialized_parameters = []
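Taken on its own, the new scheme is easy to demonstrate outside the Model class. Below is a self-contained sketch, assuming an arbitrary Hugging Face tokenizer ("gpt2" here is only a stand-in; the tests above load their own tokenizer via AutoTokenizer): decode the already-emitted window and the extended window, and emit the difference only once it has grown and contains no U+FFFD replacement character, which would signal an incomplete UTF-8 byte sequence.

    from typing import List, Tuple

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in checkpoint


    def decode_token(
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
    ) -> Tuple[str, int, int]:
        # Decode the previously emitted window and the window extended with
        # the pending ids; their difference is the newly finished text.
        prefix_text = tokenizer.decode(all_input_ids[prefix_offset:read_offset])
        new_text = tokenizer.decode(all_input_ids[prefix_offset:])
        if len(new_text) > len(prefix_text) and "�" not in new_text:
            # Emit the finished suffix and slide the window forward.
            return new_text[len(prefix_text) :], read_offset, len(all_input_ids)
        # An incomplete multi-byte character (or no growth): emit nothing and
        # retry with the same offsets once more ids arrive.
        return "", prefix_offset, read_offset


    # Streaming usage, mirroring the tests in the first file:
    truth = "Hello here, this is a simple test"
    ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
    decoded, prefix_offset, read_offset = "", 0, 0
    for i in range(len(ids)):
        text, prefix_offset, read_offset = decode_token(
            ids[: i + 1], prefix_offset, read_offset
        )
        decoded += text
    assert decoded == truth

Decoding from prefix_offset (the start of the previously emitted chunk) rather than detokenizing the new ids in isolation lets tokenizers that decide spacing from surrounding ids produce the same text as a one-shot decode; that behavior is exactly what test_decode_streaming_english_spaces pins down.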