mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
fix tests
This commit is contained in:
parent
1dd0cf63df
commit
62ff3816fb
@ -332,7 +332,8 @@ def generate_load():
|
|||||||
client: AsyncClient, prompt: str, max_new_tokens: int, n: int
|
client: AsyncClient, prompt: str, max_new_tokens: int, n: int
|
||||||
) -> List[Response]:
|
) -> List[Response]:
|
||||||
futures = [
|
futures = [
|
||||||
client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
|
client.generate(prompt, max_new_tokens=max_new_tokens, prefill_details=True)
|
||||||
|
for _ in range(n)
|
||||||
]
|
]
|
||||||
|
|
||||||
return await asyncio.gather(*futures)
|
return await asyncio.gather(*futures)
|
||||||
|
@ -19,6 +19,7 @@ async def test_bloom_560m(bloom_560, response_snapshot):
|
|||||||
"Pour déguster un ortolan, il faut tout d'abord",
|
"Pour déguster un ortolan, il faut tout d'abord",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
top_p=0.9,
|
top_p=0.9,
|
||||||
|
prefill_details=True,
|
||||||
seed=0,
|
seed=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -40,6 +41,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot):
|
|||||||
truncate=5,
|
truncate=5,
|
||||||
typical_p=0.9,
|
typical_p=0.9,
|
||||||
watermark=True,
|
watermark=True,
|
||||||
|
prefill_details=True,
|
||||||
seed=0,
|
seed=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
|
|||||||
"Pour déguster un ortolan, il faut tout d'abord",
|
"Pour déguster un ortolan, il faut tout d'abord",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
top_p=0.9,
|
top_p=0.9,
|
||||||
|
prefill_details=True,
|
||||||
seed=0,
|
seed=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot):
|
|||||||
response = await flash_falcon.generate(
|
response = await flash_falcon.generate(
|
||||||
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
|
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
|
prefill_details=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
assert response.details.generated_tokens == 10
|
||||||
@ -40,6 +41,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
|
|||||||
truncate=5,
|
truncate=5,
|
||||||
typical_p=0.9,
|
typical_p=0.9,
|
||||||
watermark=True,
|
watermark=True,
|
||||||
|
prefill_details=True,
|
||||||
seed=0,
|
seed=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -16,7 +16,9 @@ async def flash_llama(flash_llama_handle):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_llama(flash_llama, response_snapshot):
|
async def test_flash_llama(flash_llama, response_snapshot):
|
||||||
response = await flash_llama.generate("Test request", max_new_tokens=10)
|
response = await flash_llama.generate(
|
||||||
|
"Test request", max_new_tokens=10, prefill_details=True
|
||||||
|
)
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
assert response.details.generated_tokens == 10
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
@ -37,6 +39,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot):
|
|||||||
truncate=5,
|
truncate=5,
|
||||||
typical_p=0.9,
|
typical_p=0.9,
|
||||||
watermark=True,
|
watermark=True,
|
||||||
|
prefill_details=True,
|
||||||
seed=0,
|
seed=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
|
|||||||
response = await flash_neox.generate(
|
response = await flash_neox.generate(
|
||||||
"<|USER|>What's your mood today?<|ASSISTANT|>",
|
"<|USER|>What's your mood today?<|ASSISTANT|>",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
|
prefill_details=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
assert response.details.generated_tokens == 10
|
||||||
|
@ -18,6 +18,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot):
|
|||||||
response = await flash_neox_sharded.generate(
|
response = await flash_neox_sharded.generate(
|
||||||
"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
|
"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
|
prefill_details=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
assert response.details.generated_tokens == 10
|
||||||
|
@ -15,7 +15,9 @@ async def flash_santacoder(flash_santacoder_handle):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_flash_santacoder(flash_santacoder, response_snapshot):
|
async def test_flash_santacoder(flash_santacoder, response_snapshot):
|
||||||
response = await flash_santacoder.generate("def print_hello", max_new_tokens=10)
|
response = await flash_santacoder.generate(
|
||||||
|
"def print_hello", max_new_tokens=10, prefill_details=True
|
||||||
|
)
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
assert response.details.generated_tokens == 10
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
@ -16,7 +16,9 @@ async def flash_starcoder(flash_starcoder_handle):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_starcoder(flash_starcoder, response_snapshot):
|
async def test_flash_starcoder(flash_starcoder, response_snapshot):
|
||||||
response = await flash_starcoder.generate("def print_hello", max_new_tokens=10)
|
response = await flash_starcoder.generate(
|
||||||
|
"def print_hello", max_new_tokens=10, prefill_details=True
|
||||||
|
)
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
assert response.details.generated_tokens == 10
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
@ -26,7 +28,12 @@ async def test_flash_starcoder(flash_starcoder, response_snapshot):
|
|||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot):
|
async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot):
|
||||||
response = await flash_starcoder.generate(
|
response = await flash_starcoder.generate(
|
||||||
"def print_hello", max_new_tokens=60, temperature=0.2, top_p=0.95, seed=0
|
"def print_hello",
|
||||||
|
max_new_tokens=60,
|
||||||
|
temperature=0.2,
|
||||||
|
top_p=0.95,
|
||||||
|
prefill_details=True,
|
||||||
|
seed=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.details.generated_tokens == 60
|
assert response.details.generated_tokens == 60
|
||||||
|
@ -19,6 +19,7 @@ async def test_mt0_base(mt0_base, response_snapshot):
|
|||||||
"Why is the sky blue?",
|
"Why is the sky blue?",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
top_p=0.9,
|
top_p=0.9,
|
||||||
|
prefill_details=True,
|
||||||
seed=0,
|
seed=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -40,6 +41,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
|
|||||||
truncate=5,
|
truncate=5,
|
||||||
typical_p=0.9,
|
typical_p=0.9,
|
||||||
watermark=True,
|
watermark=True,
|
||||||
|
prefill_details=True,
|
||||||
seed=0,
|
seed=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ async def test_t5_sharded(t5_sharded, response_snapshot):
|
|||||||
response = await t5_sharded.generate(
|
response = await t5_sharded.generate(
|
||||||
"Please answer the following question. What is the boiling point of Nitrogen?",
|
"Please answer the following question. What is the boiling point of Nitrogen?",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
|
prefill_details=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
@ -24,6 +24,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
|||||||
return generate_pb2.Request(
|
return generate_pb2.Request(
|
||||||
id=0,
|
id=0,
|
||||||
inputs="Test",
|
inputs="Test",
|
||||||
|
prefill_logprobs=True,
|
||||||
truncate=100,
|
truncate=100,
|
||||||
parameters=default_pb_parameters,
|
parameters=default_pb_parameters,
|
||||||
stopping_parameters=default_pb_stop_parameters,
|
stopping_parameters=default_pb_stop_parameters,
|
||||||
|
@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
|||||||
return generate_pb2.Request(
|
return generate_pb2.Request(
|
||||||
id=0,
|
id=0,
|
||||||
inputs="Test",
|
inputs="Test",
|
||||||
|
prefill_logprobs=True,
|
||||||
truncate=100,
|
truncate=100,
|
||||||
parameters=default_pb_parameters,
|
parameters=default_pb_parameters,
|
||||||
stopping_parameters=default_pb_stop_parameters,
|
stopping_parameters=default_pb_stop_parameters,
|
||||||
|
@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
|||||||
return generate_pb2.Request(
|
return generate_pb2.Request(
|
||||||
id=0,
|
id=0,
|
||||||
inputs="def",
|
inputs="def",
|
||||||
|
prefill_logprobs=True,
|
||||||
truncate=100,
|
truncate=100,
|
||||||
parameters=default_pb_parameters,
|
parameters=default_pb_parameters,
|
||||||
stopping_parameters=default_pb_stop_parameters,
|
stopping_parameters=default_pb_stop_parameters,
|
||||||
@ -31,6 +32,7 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
|||||||
return generate_pb2.Request(
|
return generate_pb2.Request(
|
||||||
id=0,
|
id=0,
|
||||||
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
|
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
|
||||||
|
prefill_logprobs=True,
|
||||||
truncate=100,
|
truncate=100,
|
||||||
parameters=default_pb_parameters,
|
parameters=default_pb_parameters,
|
||||||
stopping_parameters=default_pb_stop_parameters,
|
stopping_parameters=default_pb_stop_parameters,
|
||||||
|
@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
|||||||
return generate_pb2.Request(
|
return generate_pb2.Request(
|
||||||
id=0,
|
id=0,
|
||||||
inputs="Test",
|
inputs="Test",
|
||||||
|
prefill_logprobs=True,
|
||||||
truncate=100,
|
truncate=100,
|
||||||
parameters=default_pb_parameters,
|
parameters=default_pb_parameters,
|
||||||
stopping_parameters=default_pb_stop_parameters,
|
stopping_parameters=default_pb_stop_parameters,
|
||||||
|
Loading…
Reference in New Issue
Block a user