test(neuron): simplify sampling test

This commit is contained in:
David Corvoysier 2025-03-12 08:37:28 +00:00
parent 2bb2bdb6a8
commit 23b449d238

View File

@@ -49,17 +49,11 @@ async def test_model_single_request(tgi_service):
max_new_tokens=128, max_new_tokens=128,
seed=42, seed=42,
) )
sample_expectations = { # The response must be different
"gpt2": "Deep Learning", assert not response.startswith(greedy_expectations[service_name])
"llama": "Deep Learning",
"mistral": "Deep learning",
"qwen2": "Deep Learning",
"granite": "Deep learning",
}
assert sample_expectations[service_name] in response
# Sampling with stop sequence # Sampling with stop sequence (using one of the words returned from the previous test)
stop_sequence = sample_expectations[service_name][-5:] stop_sequence = response.split(" ")[-5]
response = await tgi_service.client.text_generation( response = await tgi_service.client.text_generation(
"What is Deep Learning?", "What is Deep Learning?",
do_sample=True, do_sample=True,