Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 03:44:54 +00:00
Simplifying streaming decode.

parent d2a99b4294
commit 1aa31bb5cc
@@ -6,8 +6,7 @@ from transformers import AutoTokenizer
 from text_generation_server.models import Model
 
 
-@pytest.mark.private
-def test_decode_streaming():
+def get_test_model():
     class TestModel(Model):
         def batch_type(self):
             raise NotImplementedError
@@ -20,7 +19,34 @@ def test_decode_streaming():
     model = TestModel(
         torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
     )
+    return model
+
+
+@pytest.mark.private
+def test_decode_streaming_english_spaces():
+    model = get_test_model()
+    truth = "Hello here, this is a simple test"
+    all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243]
+    assert (
+        all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"]
+    )
+
+    decoded_text = ""
+    offset = 0
+    token_offset = 0
+    for i in range(len(all_input_ids)):
+        text, offset, token_offset = model.decode_token(
+            all_input_ids[: i + 1], offset, token_offset
+        )
+        decoded_text += text
+
+    assert decoded_text == truth
+
+
+@pytest.mark.private
+def test_decode_streaming_chinese_utf8():
+    model = get_test_model()
+    truth = "我很感谢你的热情"
 
     all_input_ids = [
         30672,
         232,
@@ -40,11 +66,9 @@ def test_decode_streaming():
         30993,
     ]
 
-    truth = "我很感谢你的热情"
-
     decoded_text = ""
-    offset = None
-    token_offset = None
+    offset = 0
+    token_offset = 0
     for i in range(len(all_input_ids)):
         text, offset, token_offset = model.decode_token(
             all_input_ids[: i + 1], offset, token_offset
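The Chinese test above streams a string where several consecutive token ids appear to be single UTF-8 bytes (SentencePiece byte fallback), so an intermediate decode can stop in the middle of a multi-byte character. A tokenizer-free sketch, using only the Python standard library, of the effect that the "�" guard in decode_token protects against:

    data = "我".encode("utf-8")  # three bytes: 0xE6 0x88 0x91
    partial = data[:2].decode("utf-8", errors="replace")
    complete = data.decode("utf-8")

    assert "�" in partial   # a truncated sequence decodes to the replacement character
    assert complete == "我"  # all three bytes decode cleanly

Until the pending bytes form complete characters, decode_token returns an empty string and leaves both offsets unchanged.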
@@ -56,52 +56,25 @@ class Model(ABC):
     def decode_token(
         self,
         all_input_ids: List[int],
-        offset: Optional[int] = None,
-        token_offset: Optional[int] = None,
-    ) -> Tuple[str, Optional[int], Optional[int]]:
+        prefix_offset: int = 0,
+        read_offset: int = 0,
+    ) -> Tuple[str, int, int]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
-        if all_input_ids[-1] in self.all_special_ids:
-            return (
-                self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False),
-                None,
-                None,
-            )
-
-        if token_offset is None:
-            token_offset = max(len(all_input_ids) - self.decode_buffer, 0)
-            # left token buffer
-            if self.decode_buffer > 1:
-                # Decode token_offset token minus last one and token_offset tokens
-                raw_texts = self.tokenizer.batch_decode(
-                    [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
-                    skip_special_tokens=False,
-                )
-
-                # default offset is only the last token
-                offset = len(raw_texts[0])
-                sequence_text = raw_texts[1]
-            else:
-                # Only decode the last token without using a token buffer
-                sequence_text = self.tokenizer.decode(
-                    all_input_ids[-1], skip_special_tokens=False
-                )
-                # no offset in this case
-                offset = 0
-        else:
-            assert offset is not None
-            sequence_text = self.tokenizer.decode(
-                all_input_ids[token_offset:],
-                skip_special_tokens=False,
-            )
-
-        # get text
-        token_text = sequence_text[offset:]
-
-        # if text is utf-8
-        if token_text and token_text[-1] != "�":
-            return token_text, None, None
-        else:
-            return "", offset, token_offset
+        # Compatibility layer for old None values.
+        if prefix_offset is None:
+            prefix_offset = 0
+        if read_offset is None:
+            read_offset = 0
+
+        prefix_text = self.tokenizer.decode(all_input_ids[prefix_offset:read_offset])
+        new_text = self.tokenizer.decode(all_input_ids[prefix_offset:])
+
+        if len(new_text) > len(prefix_text) and "�" not in new_text:
+            new_text = new_text[len(prefix_text) :]
+            return new_text, read_offset, len(all_input_ids)
+        else:
+            return "", prefix_offset, read_offset
 
     def check_initialized(self):
         uninitialized_parameters = []
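For reference, a minimal standalone sketch of the prefix/read offset scheme from the hunk above, with the tokenizer passed in explicitly so it can be tried outside the Model class. The gpt2 checkpoint is only a stand-in for illustration (the repository's tests use a private llama tokenizer), and the sketch assumes transformers is installed:

    from typing import List, Tuple

    from transformers import AutoTokenizer


    def decode_token(
        tokenizer,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
    ) -> Tuple[str, int, int]:
        # Decode the previously emitted window alone, then the window plus the new
        # tokens: sharing the prefix keeps the tokenizer's space/cleanup behaviour
        # identical in both calls, so their difference is exactly the new text.
        prefix_text = tokenizer.decode(all_input_ids[prefix_offset:read_offset])
        new_text = tokenizer.decode(all_input_ids[prefix_offset:])

        if len(new_text) > len(prefix_text) and "�" not in new_text:
            # Complete characters were produced: emit the delta and slide the window.
            return new_text[len(prefix_text) :], read_offset, len(all_input_ids)
        # Otherwise hold the text back (e.g. an unfinished multi-byte character).
        return "", prefix_offset, read_offset


    if __name__ == "__main__":
        tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer
        truth = "Hello here, this is a simple test"
        ids = tokenizer(truth, add_special_tokens=False)["input_ids"]

        streamed, prefix_offset, read_offset = "", 0, 0
        for i in range(len(ids)):
            text, prefix_offset, read_offset = decode_token(
                tokenizer, ids[: i + 1], prefix_offset, read_offset
            )
            streamed += text

        assert streamed == truth

Returning read_offset as the next prefix_offset means the decode window always starts at the last emitted chunk, so it stays small as long as text keeps flowing, no matter how long the generation gets.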