test(neuron): use greedy for stop sequences

This commit is contained in:
David Corvoysier 2025-08-04 15:18:31 +00:00
parent 4870824781
commit 7093a372bb

View File

@@ -29,15 +29,15 @@ async def test_model_single_request(tgi_service):
assert response.generated_text == greedy_expectations[service_name] assert response.generated_text == greedy_expectations[service_name]
# Greedy bounded with input # Greedy bounded with input
response = await tgi_service.client.text_generation( greedy_response = await tgi_service.client.text_generation(
"What is Deep Learning?", "What is Deep Learning?",
max_new_tokens=17, max_new_tokens=17,
return_full_text=True, return_full_text=True,
details=True, details=True,
decoder_input_details=True, decoder_input_details=True,
) )
assert response.details.generated_tokens == 17 assert greedy_response.details.generated_tokens == 17
assert response.generated_text == prompt + greedy_expectations[service_name] assert greedy_response.generated_text == prompt + greedy_expectations[service_name]
# Sampling # Sampling
response = await tgi_service.client.text_generation( response = await tgi_service.client.text_generation(
@@ -52,16 +52,12 @@ async def test_model_single_request(tgi_service):
# The response must be different # The response must be different
assert not response.startswith(greedy_expectations[service_name]) assert not response.startswith(greedy_expectations[service_name])
# Sampling with stop sequence (using one of the words returned from the previous test) # Greedy with stop sequence (using one of the words returned from the previous test)
stop_sequence = response.split(" ")[-5] stop_sequence = greedy_response.generated_text.split(" ")[-5]
response = await tgi_service.client.text_generation( response = await tgi_service.client.text_generation(
"What is Deep Learning?", "What is Deep Learning?",
do_sample=True, do_sample=False,
top_k=50,
top_p=0.9,
repetition_penalty=1.2,
max_new_tokens=128, max_new_tokens=128,
seed=42,
stop_sequences=[stop_sequence], stop_sequences=[stop_sequence],
) )
assert response.endswith(stop_sequence) assert response.endswith(stop_sequence)