text-generation-inference/backends/neuron/tests/server/test_decode.py

60 lines
2.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from helpers import create_request
from text_generation_server.generator import NeuronGenerator
from text_generation_server.pb.generate_pb2 import Batch
def test_decode(neuron_model_config):
    """Run the single-request decode check in sampling mode, then in greedy mode.

    A single generator is loaded from the path stored under
    ``neuron_model_config["neuron_model_path"]`` and reused for both passes;
    ``generator.clear()`` is called after each pass.
    """
    config_name = neuron_model_config["name"]
    model_path = neuron_model_config["neuron_model_path"]
    generator = NeuronGenerator.from_pretrained(model_path)
    # Sampling first, greedy second.
    for do_sample, mode in ((True, "sample"), (False, "greedy")):
        print(f"{config_name}[{mode}]")
        _test_decode(config_name, generator, do_sample)
        generator.clear()
def _test_decode(config_name, generator, do_sample):
    """Prefill one request, decode it to its token budget, and check the text.

    Args:
        config_name: Key used to look up the expected text for the model
            under test; an unknown name raises ``KeyError``.
        generator: Object providing ``model.max_length``, ``prefill`` and
            ``decode``.
        do_sample: Forwarded to ``create_request``. When true, the expected
            text only has to be contained in the output; otherwise the output
            must match the expected text exactly.
    """
    prompt = (
        "It was a bright cold day in April, and the clocks were striking thirteen."
    )
    max_new_tokens = 20
    request = create_request(
        id=0,
        inputs=prompt,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
    )
    max_length = generator.model.max_length
    generations, pending = generator.prefill(
        Batch(id=0, requests=[request], size=1, max_tokens=max_length)
    )
    # Prefill already produced the first token, so only max_new_tokens - 1
    # decode calls remain.
    for _step in range(1, max_new_tokens):
        assert pending.size == 1
        assert pending.max_tokens == max_length
        assert len(generations) == 1
        assert len(generations[0].tokens.ids) == 1
        generations, pending = generator.decode([pending])

    # After the last decode call no batch should be left to process.
    assert pending is None
    assert len(generations) == 1
    result = generations[0].generated_text
    assert result.generated_tokens == max_new_tokens
    # NOTE(review): 0 is presumably the "length" finish reason -- confirm
    # against the protobuf enum definition.
    assert result.finish_reason == 0

    if not do_sample:
        print(result.text)
        # NOTE(review): "Orwells" in the llama entry may have lost a
        # typographic apostrophe somewhere upstream -- confirm against the
        # actual model output before relying on it.
        greedy_expectations = {
            "gpt2": '\n\n"I\'m going to go to bed," I said.\n\n"I\'m going',
            "llama": " George Orwells classic dystopian novel, 1984, begins with this ominous sentence. The story",
            "mistral": "\nThe clocks were striking thirteen.\nThe clocks were striking thirteen.",
            "qwen2": " I was sitting in my room, staring at the ceiling, when the door opened and in came a",
            "granite": "\n\nThis opening line from George Orwell's dystopian novel \"198",
        }
        expected = greedy_expectations[config_name]
        assert result.text == expected
        return

    # Sampling mode: only require the expected fragment to appear somewhere.
    sampled_expectations = {
        "gpt2": " The sun was set",
        "llama": "George Orwell, 1984",
        "mistral": "The sky was",
        "qwen2": " A young woman with",
        "granite": "1984, George Orwell",
    }
    expected = sampled_expectations[config_name]
    assert expected in result.text