Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 03:14:53 +00:00)
test: sample is not deterministic
Also modify the temperature in decode test to avoid granite early stopping.
commit e0aa213411 (parent b11d663ca0)
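Greedy decoding always picks the highest-probability token, so its output is reproducible and can be compared against a fixed string; sampling draws from a temperature-scaled distribution and can produce different text on each run, which is why the sample-mode assertions below are dropped or relaxed. A minimal self-contained sketch of that distinction (the toy logits and helper names are illustrative, not taken from the repository):

import math
import random

def softmax(logits, temperature=1.0):
    # Temperature < 1.0 sharpens the distribution, > 1.0 flattens it.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def greedy_pick(logits):
    # Greedy decoding: always the arg-max token, identical on every run.
    return max(range(len(logits)), key=lambda i: logits[i])

def sampled_pick(logits, temperature=0.9):
    # Sampling: a random draw from the temperature-scaled distribution,
    # so repeated runs may pick different tokens and produce different text.
    probs = softmax(logits, temperature)
    return random.choices(range(len(logits)), weights=probs, k=1)[0]

toy_logits = [2.0, 1.5, 0.1]  # pretend next-token scores over a 3-token vocabulary
print(greedy_pick(toy_logits))                        # always 0
print({sampled_pick(toy_logits) for _ in range(20)})  # usually more than one id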
@@ -11,7 +11,14 @@ def test_decode(neuron_model_config):
     for do_sample in [True, False]:
         mode = "sample" if do_sample else "greedy"
         print(f"{config_name}[{mode}]")
-        _test_decode(config_name, generator, do_sample)
+        generated_text = _test_decode(config_name, generator, do_sample)
+        if not do_sample:
+            expected_text = {
+                "llama": " The world was holding its breath as the world's top scientists and engineers gathered at the secret underground facility",
+                "qwen2": " I was sitting in my room, staring at the clock, when a knock at the door. I",
+                "granite": "\n\nThis opening line is from George Orwell's dystopian novel, \"1",
+            }[config_name]
+            assert generated_text == expected_text
         generator.clear()
@@ -21,7 +28,11 @@ def _test_decode(config_name, generator, do_sample):
     )
     max_new_tokens = 20
     request = create_request(
-        id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample
+        id=0,
+        inputs=input_text,
+        max_new_tokens=max_new_tokens,
+        do_sample=do_sample,
+        temperature=0.9,
     )
     max_length = generator.model.neuron_config.sequence_length
     batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
@@ -38,18 +49,4 @@ def _test_decode(config_name, generator, do_sample):
     output = generations[0].generated_text
     assert output.generated_tokens == max_new_tokens
     assert output.finish_reason == 0
-    if do_sample:
-        expected_text = {
-            "llama": " I sat alone in the café",
-            "qwen2": " The air was so still",
-            "granite": "1984, George Orwell",
-        }[config_name]
-        assert expected_text in output.text
-    else:
-        print(output.text)
-        expected_text = {
-            "llama": " The world was holding its breath as the world's top scientists and engineers gathered at the secret underground facility",
-            "qwen2": " I was sitting in my room, staring at the ceiling, when the door opened and in came a",
-            "granite": "\n\nThis opening line from George Orwell's dystopian novel \"198",
-        }[config_name]
-        assert output.text == expected_text
+    return output.text
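Taken together, the hunks above make `_test_decode` return the generated text instead of asserting on it, and move the exact-match check into the caller, where it only runs in greedy mode; in sample mode only the structural properties (token count and finish reason) are still checked. The resulting pattern can be sketched as follows (an illustrative helper, not code from the repository; 0 is simply the finish-reason value the test expects for a length-limited generation):

def check_decode_output(output_text, generated_tokens, finish_reason,
                        do_sample, max_new_tokens, expected_text):
    # Structural checks hold in both modes: the request asked for exactly
    # max_new_tokens and is expected to stop because of the length limit.
    assert generated_tokens == max_new_tokens
    assert finish_reason == 0
    # Exact-text comparison is only meaningful for deterministic (greedy)
    # decoding; sampled text may legitimately differ between runs.
    if not do_sample:
        assert output_text == expected_text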
@@ -44,24 +44,18 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
     # because of static batching
     assert next_batch.max_tokens == batch_size * max_length
     assert len(generations) == batch_size
-    if do_sample:
-        expectations = {
-            "llama": [358, " I"],
-            "qwen2": [576, " The"],
-            "granite": [308, " ("],
-        }[config_name]
-    else:
-        expectations = {
-            "llama": [578, " The"],
-            "qwen2": [358, " I"],
-            "granite": [203, "\n"],
-        }[config_name]
-    for g in generations:
-        tokens = g.tokens
-        assert tokens.ids[0] == expectations[0]
-        assert tokens.texts[0] == expectations[1]
+    expectations = {
+        "llama": [578, " The"],
+        "qwen2": [358, " I"],
+        "granite": [203, "\n"],
+    }[config_name]
+    # Greedy mode should always generate the same output
+    if not do_sample:
+        for g in generations:
+            tokens = g.tokens
+            assert tokens.ids[0] == expectations[0]
+            assert tokens.texts[0] == expectations[1]


 def test_prefill_truncate(neuron_model_config):
     config_name = neuron_model_config["name"]
     neuron_model_path = neuron_model_config["neuron_model_path"]
@@ -88,8 +82,8 @@ def test_prefill_truncate(neuron_model_config):
     # be different because of the truncation
     expectations = {
         "llama": [" He", "iens", "\x08", " He"],
-        "qwen2": [" He", " The", " He", " He"],
-        "granite": ["\n", "\n", " I", " He"],
+        "qwen2": [" He", "<|endoftext|>", " ", " The"],
+        "granite": ["\n", "\n", "\n", "\n"],
     }[config_name]
     for i, g in enumerate(generations):
         tokens = g.tokens