mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-08 19:04:52 +00:00
test(neuron): use greedy for stop sequences
This commit is contained in:
parent
4870824781
commit
7093a372bb
@ -29,15 +29,15 @@ async def test_model_single_request(tgi_service):
|
||||
assert response.generated_text == greedy_expectations[service_name]
|
||||
|
||||
# Greedy bounded with input
|
||||
response = await tgi_service.client.text_generation(
|
||||
greedy_response = await tgi_service.client.text_generation(
|
||||
"What is Deep Learning?",
|
||||
max_new_tokens=17,
|
||||
return_full_text=True,
|
||||
details=True,
|
||||
decoder_input_details=True,
|
||||
)
|
||||
assert response.details.generated_tokens == 17
|
||||
assert response.generated_text == prompt + greedy_expectations[service_name]
|
||||
assert greedy_response.details.generated_tokens == 17
|
||||
assert greedy_response.generated_text == prompt + greedy_expectations[service_name]
|
||||
|
||||
# Sampling
|
||||
response = await tgi_service.client.text_generation(
|
||||
@ -52,16 +52,12 @@ async def test_model_single_request(tgi_service):
|
||||
# The response must be different
|
||||
assert not response.startswith(greedy_expectations[service_name])
|
||||
|
||||
# Sampling with stop sequence (using one of the words returned from the previous test)
|
||||
stop_sequence = response.split(" ")[-5]
|
||||
# Greedy with stop sequence (using one of the words returned from the previous test)
|
||||
stop_sequence = greedy_response.generated_text.split(" ")[-5]
|
||||
response = await tgi_service.client.text_generation(
|
||||
"What is Deep Learning?",
|
||||
do_sample=True,
|
||||
top_k=50,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1.2,
|
||||
do_sample=False,
|
||||
max_new_tokens=128,
|
||||
seed=42,
|
||||
stop_sequences=[stop_sequence],
|
||||
)
|
||||
assert response.endswith(stop_sequence)
|
||||
|
Loading…
Reference in New Issue
Block a user