from helpers import create_request

from text_generation_server.generator import NeuronGenerator
from text_generation_server.pb.generate_pb2 import Batch


def test_decode(neuron_model_config):
    """Verify that decoding a single request generates the expected output."""
    config_name = neuron_model_config["name"]
    neuron_model_path = neuron_model_config["neuron_model_path"]
    generator = NeuronGenerator.from_pretrained(neuron_model_path)
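    # Decode the same prompt in both sampling and greedy modes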
    for do_sample in [True, False]:
        mode = "sample" if do_sample else "greedy"
        print(f"{config_name}[{mode}]")
        generated_text = _test_decode(config_name, generator, do_sample)
        if not do_sample:
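            # Sampling is not deterministic, so only greedy outputs are checked against a reference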
            expected_text = {
                "llama": " The world was holding its breath as the world's top scientists and engineers gathered at the secret underground facility",
                "qwen2": " I was sitting in my room, staring at the clock, when a knock at the door. I",
                "granite": "\n\nThis opening line is from George Orwell's dystopian novel, \"1",
            }[config_name]
            assert generated_text == expected_text
        generator.clear()


def _test_decode(config_name, generator, do_sample):
    input_text = (
        "It was a bright cold day in April, and the clocks were striking thirteen."
    )
    max_new_tokens = 20
    request = create_request(
        id=0,
        inputs=input_text,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=0.9,
    )
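    # Neuron models are compiled with a static sequence length, which bounds the batch token budget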
    max_length = generator.model.neuron_config.sequence_length
    batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
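    # Prefill processes the prompt and emits the first generated token for the request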
    generations, next_batch = generator.prefill(batch)
    # We already generated one token: call decode max_new_tokens - 1 times
    for _ in range(max_new_tokens - 1):
        assert next_batch.size == 1
        assert next_batch.max_tokens == max_length
        assert len(generations) == 1
        assert len(generations[0].tokens.ids) == 1
        generations, next_batch = generator.decode([next_batch])
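    # After max_new_tokens tokens the request is finished, so no batch is returned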
    assert next_batch is None
    assert len(generations) == 1
    output = generations[0].generated_text
    assert output.generated_tokens == max_new_tokens
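    # finish_reason 0 corresponds to FINISH_REASON_LENGTH: generation stopped at max_new_tokens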
    assert output.finish_reason == 0
    return output.text