mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 14:22:08 +00:00
56 lines
2.6 KiB
Python
from helpers import create_request

from text_generation_server.generator import NeuronGenerator
from text_generation_server.pb.generate_pb2 import Batch


def test_decode(neuron_model_config):
    """Verify that a decoding for a single request generates the expected output.

    Loads the neuron model once, then exercises both sampling and greedy
    decoding against it, clearing the generator's state between runs.
    """
    config_name = neuron_model_config["name"]
    generator = NeuronGenerator.from_pretrained(neuron_model_config["neuron_model_path"])
    for do_sample in (True, False):
        # Label the run so test output distinguishes sample vs greedy mode.
        print(f"{config_name}[{'sample' if do_sample else 'greedy'}]")
        _test_decode(config_name, generator, do_sample)
        # Reset generator state so the next mode starts from a clean slate.
        generator.clear()
|
|||
|
|
|||
|
|
|||
|
def _test_decode(config_name, generator, do_sample):
    """Prefill a single request, decode it to completion, and check the text.

    Args:
        config_name: key selecting the per-model expected output.
        generator: a NeuronGenerator already loaded with the model under test.
        do_sample: True for (seeded) sampling, False for greedy decoding.
    """
    prompt = "It was a bright cold day in April, and the clocks were striking thirteen."
    max_new_tokens = 20
    request = create_request(id=0, inputs=prompt, max_new_tokens=max_new_tokens, do_sample=do_sample)
    max_length = generator.model.max_length
    generations, next_batch = generator.prefill(
        Batch(id=0, requests=[request], size=1, max_tokens=max_length)
    )
    # Prefill already emitted the first token; decode produces the rest.
    for _ in range(max_new_tokens - 1):
        # Invariants that must hold on every intermediate step: one live
        # request producing exactly one token per decode call.
        assert next_batch.size == 1
        assert next_batch.max_tokens == max_length
        assert len(generations) == 1
        assert len(generations[0].tokens.ids) == 1
        generations, next_batch = generator.decode([next_batch])
    # The final decode step must have exhausted the batch.
    assert next_batch is None
    assert len(generations) == 1
    output = generations[0].generated_text
    assert output.generated_tokens == max_new_tokens
    # finish_reason 0 — presumably "length reached"; confirm against generate_pb2.
    assert output.finish_reason == 0
    if do_sample:
        # Sampling is seeded, so each model produces a stable snippet that
        # should appear somewhere in the generated text.
        sampled_snippets = {
            "gpt2": " The sun was set",
            "llama": "George Orwell, 1984",
            "mistral": "The sky was",
            "qwen2": " A young woman with",
            "granite": "1984, George Orwell",
        }
        assert sampled_snippets[config_name] in output.text
    else:
        print(output.text)
        # Greedy decoding is deterministic: require an exact match.
        greedy_outputs = {
            "gpt2": '\n\n"I\'m going to go to bed," I said.\n\n"I\'m going',
            "llama": " George Orwell’s classic dystopian novel, 1984, begins with this ominous sentence. The story",
            "mistral": "\nThe clocks were striking thirteen.\nThe clocks were striking thirteen.",
            "qwen2": " I was sitting in my room, staring at the ceiling, when the door opened and in came a",
            "granite": "\n\nThis opening line from George Orwell's dystopian novel \"198",
        }
        assert output.text == greedy_outputs[config_name]
|