text-generation-inference/server/tests/utils/test_tokens.py

45 lines
1.4 KiB
Python
Raw Normal View History

2023-03-07 17:52:22 +00:00
from text_generation_server.utils.tokens import (
2022-12-12 17:25:22 +00:00
StopSequenceCriteria,
StoppingCriteria,
FinishReason,
2022-12-08 17:49:33 +00:00
)
2022-12-12 17:25:22 +00:00
def test_stop_sequence_criteria():
    """A stop sequence matches only when it is complete and nothing follows it."""
    stop = StopSequenceCriteria("/test;")

    # Partial prefixes of the stop sequence must not match.
    for partial in ("/", "/test"):
        assert not stop(partial)

    # The complete sequence matches.
    assert stop("/test;")

    # Once trailing text follows the sequence, it no longer matches.
    assert not stop("/test; ")
2022-12-12 17:25:22 +00:00
def test_stop_sequence_criteria_escape():
    """A stop sequence containing `|` characters is matched as literal text.

    NOTE(review): the test name suggests this guards against the sequence
    being interpreted as a regex (where `|` would mean alternation) —
    confirm against StopSequenceCriteria's implementation.
    """
    stop = StopSequenceCriteria("<|stop|>")

    # Neither a lone "<" nor an unterminated prefix counts as a match.
    for partial in ("<", "<|stop"):
        assert not stop(partial)

    # The full literal sequence matches ...
    assert stop("<|stop|>")
    # ... but not when extra text trails it.
    assert not stop("<|stop|> ")
2022-12-16 15:03:39 +00:00
def test_stopping_criteria():
    """StoppingCriteria reports a stop-sequence finish once "/test;" is produced."""
    stopping = StoppingCriteria(
        0, [StopSequenceCriteria("/test;")], max_new_tokens=5
    )

    # The first token's text is only a prefix of the stop sequence.
    first = stopping(65827, "/test")
    assert first == (False, None)

    # The next token's text (";") completes the sequence together with the
    # text seen on the previous call.
    second = stopping(30, ";")
    assert second == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)
2022-12-12 17:25:22 +00:00
2022-12-16 15:03:39 +00:00
def test_stopping_criteria_eos():
    """Emitting the configured EOS token id ends generation with an EOS finish."""
    eos_token_id = 0
    stopping = StoppingCriteria(
        eos_token_id, [StopSequenceCriteria("/test;")], max_new_tokens=5
    )

    # Any other token id keeps generation going.
    assert stopping(1, "") == (False, None)
    # The EOS token id stops it.
    assert stopping(eos_token_id, "") == (
        True,
        FinishReason.FINISH_REASON_EOS_TOKEN,
    )
2022-12-12 17:25:22 +00:00
def test_stopping_criteria_max():
    """Generation stops with a length finish on the max_new_tokens-th token."""
    limit = 5
    stopping = StoppingCriteria(
        0, [StopSequenceCriteria("/test;")], max_new_tokens=limit
    )

    # Each token before the limit is accepted without stopping.
    for _ in range(limit - 1):
        assert stopping(1, "") == (False, None)

    # The token that reaches the limit triggers the length stop.
    assert stopping(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)