mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00
test(neuron): use greedy for stop sequences
This commit is contained in:
parent
4870824781
commit
7093a372bb
@ -29,15 +29,15 @@ async def test_model_single_request(tgi_service):
|
|||||||
assert response.generated_text == greedy_expectations[service_name]
|
assert response.generated_text == greedy_expectations[service_name]
|
||||||
|
|
||||||
# Greedy bounded with input
|
# Greedy bounded with input
|
||||||
response = await tgi_service.client.text_generation(
|
greedy_response = await tgi_service.client.text_generation(
|
||||||
"What is Deep Learning?",
|
"What is Deep Learning?",
|
||||||
max_new_tokens=17,
|
max_new_tokens=17,
|
||||||
return_full_text=True,
|
return_full_text=True,
|
||||||
details=True,
|
details=True,
|
||||||
decoder_input_details=True,
|
decoder_input_details=True,
|
||||||
)
|
)
|
||||||
assert response.details.generated_tokens == 17
|
assert greedy_response.details.generated_tokens == 17
|
||||||
assert response.generated_text == prompt + greedy_expectations[service_name]
|
assert greedy_response.generated_text == prompt + greedy_expectations[service_name]
|
||||||
|
|
||||||
# Sampling
|
# Sampling
|
||||||
response = await tgi_service.client.text_generation(
|
response = await tgi_service.client.text_generation(
|
||||||
@ -52,16 +52,12 @@ async def test_model_single_request(tgi_service):
|
|||||||
# The response must be different
|
# The response must be different
|
||||||
assert not response.startswith(greedy_expectations[service_name])
|
assert not response.startswith(greedy_expectations[service_name])
|
||||||
|
|
||||||
# Sampling with stop sequence (using one of the words returned from the previous test)
|
# Greedy with stop sequence (using one of the words returned from the previous test)
|
||||||
stop_sequence = response.split(" ")[-5]
|
stop_sequence = greedy_response.generated_text.split(" ")[-5]
|
||||||
response = await tgi_service.client.text_generation(
|
response = await tgi_service.client.text_generation(
|
||||||
"What is Deep Learning?",
|
"What is Deep Learning?",
|
||||||
do_sample=True,
|
do_sample=False,
|
||||||
top_k=50,
|
|
||||||
top_p=0.9,
|
|
||||||
repetition_penalty=1.2,
|
|
||||||
max_new_tokens=128,
|
max_new_tokens=128,
|
||||||
seed=42,
|
|
||||||
stop_sequences=[stop_sequence],
|
stop_sequences=[stop_sequence],
|
||||||
)
|
)
|
||||||
assert response.endswith(stop_sequence)
|
assert response.endswith(stop_sequence)
|
||||||
|
Loading…
Reference in New Issue
Block a user