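# Tests for the token utilities in text_generation_server.utils.tokens:
# stop-sequence matching, stopping criteria, per-step top-token selection,
# and the optional pass-through tokenizer.
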
import torch

from text_generation_server.utils.tokens import (
    StopSequenceCriteria,
    StoppingCriteria,
    FinishReason,
    batch_top_tokens,
    make_tokenizer_optional,
)

from transformers import AutoTokenizer


def test_stop_sequence_criteria():
    criteria = StopSequenceCriteria("/test;")

    assert not criteria("/")
    assert not criteria("/test")
    assert criteria("/test;")
    assert not criteria("/test; ")


def test_stop_sequence_criteria_escape():
    criteria = StopSequenceCriteria("<|stop|>")

    assert not criteria("<")
    assert not criteria("<|stop")
    assert criteria("<|stop|>")
    assert not criteria("<|stop|> ")
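

# Illustrative sketch, not part of the upstream tests: as the asserts above show,
# the match is end-anchored, so the same criteria can be re-applied to the text
# accumulated so far while streaming. The chunk list below is made up.
def example_incremental_stop_check():
    criteria = StopSequenceCriteria("/test;")
    generated_text = ""
    for chunk in ["/te", "st", ";", " ignored"]:
        generated_text += chunk
        if criteria(generated_text):
            break
    return generated_text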


def test_stopping_criteria():
    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
    assert criteria(65827, "/test") == (False, None)
    assert criteria(30, ";") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)


def test_stopping_criteria_eos():
    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
    assert criteria(1, "") == (False, None)
    assert criteria(0, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN)


def test_stopping_criteria_max():
    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
    assert criteria(1, "") == (False, None)
    assert criteria(1, "") == (False, None)
    assert criteria(1, "") == (False, None)
    assert criteria(1, "") == (False, None)
    assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
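

# Illustrative sketch, not part of the upstream tests: the criteria is queried once
# per generated token with (token_id, token_text) and returns a (stop, finish_reason)
# pair, exactly as exercised above. The token stream below is made up.
def example_generation_loop():
    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
    for token_id, token_text in [(65827, "/test"), (30, ";")]:
        stop, reason = criteria(token_id, token_text)
        if stop:
            return reason
    return None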


def test_batch_top_tokens():
    top_n_tokens = [0, 2, 3, 4, 5]
    top_n_tokens_tensor = torch.tensor(top_n_tokens)
    inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)

    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
        top_n_tokens, top_n_tokens_tensor, inp_logprobs
    )

    assert topn_tok_ids[0] == []
    assert topn_tok_ids[1] == [0, 3]
    assert topn_tok_ids[2] == [0, 3, 1, 4]
    assert topn_tok_ids[3] == [0, 3, 1, 4]
    assert topn_tok_ids[4] == [0, 3, 1, 4, 2]

    assert topn_tok_logprobs[0] == []
    assert topn_tok_logprobs[1] == [-1, -2]
    assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
    assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
    assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]
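

# Illustrative sketch, not part of the upstream tests: the expected ids and logprobs
# above follow from sorting the shared logprob row in descending order (a stable sort
# keeps the lower index first on ties) and truncating to the requested n.
def example_expected_top_tokens(n):
    row = torch.tensor([-1.0, -3.0, -4.0, -2.0, -3.0])
    logprobs, ids = torch.sort(row, descending=True, stable=True)
    return ids[:n].tolist(), logprobs[:n].tolist()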


def test_pass_through_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained(
        'meta-llama/Llama-2-7b-chat-hf',
        revision=None,
        padding_side="left",
        truncation_side="left",
    )
    tokenizer.pad_token_id = 2
    make_tokenizer_optional(tokenizer)

    # With make_tokenizer_optional applied, a comma-separated string of token ids
    # is expected to pass through as those ids (left-padded to max_length) rather
    # than being re-tokenized as text.
    input = ["1, 1724, 338, 6483, 6509, 29973", "?"]
    tokenized_inputs = tokenizer(
        input,
        return_tensors="pt",
        padding="max_length",
        return_token_type_ids=False,
        truncation=True,
        max_length=1024,
    )
    assert tokenized_inputs['input_ids'].size() == torch.Size([2, 1024])
    assert torch.equal(tokenized_inputs['input_ids'][0][1018:], torch.tensor([1, 1724, 338, 6483, 6509, 29973]))
    assert torch.equal(tokenized_inputs['input_ids'][1][1023:], torch.tensor([tokenizer.pad_token_id]))
    decoded_tokens = tokenizer.decode(tokenized_inputs["input_ids"][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    assert decoded_tokens.split(',')[1018:] == ['1', '1724', '338', '6483', '6509', '29973']


if __name__ == "__main__":
    test_pass_through_tokenizer()