From 23b449d238da30c1443172365ee853a844416571 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Wed, 12 Mar 2025 08:37:28 +0000 Subject: [PATCH] test(neuron): simplify sampling test --- integration-tests/neuron/test_generate.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/integration-tests/neuron/test_generate.py b/integration-tests/neuron/test_generate.py index 6a1b4990..f0804356 100644 --- a/integration-tests/neuron/test_generate.py +++ b/integration-tests/neuron/test_generate.py @@ -49,17 +49,11 @@ async def test_model_single_request(tgi_service): max_new_tokens=128, seed=42, ) - sample_expectations = { - "gpt2": "Deep Learning", - "llama": "Deep Learning", - "mistral": "Deep learning", - "qwen2": "Deep Learning", - "granite": "Deep learning", - } - assert sample_expectations[service_name] in response + # The response must be different + assert not response.startswith(greedy_expectations[service_name]) - # Sampling with stop sequence - stop_sequence = sample_expectations[service_name][-5:] + # Sampling with stop sequence (using one of the words returned from the previous test) + stop_sequence = response.split(" ")[-5] response = await tgi_service.client.text_generation( "What is Deep Learning?", do_sample=True,