From 62ff3816fb69b5f5ae5aaaca596abc22a1f92742 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 2 Jun 2023 15:58:47 +0200 Subject: [PATCH] fix tests --- integration-tests/conftest.py | 3 ++- integration-tests/models/test_bloom_560m.py | 2 ++ integration-tests/models/test_bloom_560m_sharded.py | 1 + integration-tests/models/test_flash_falcon.py | 2 ++ integration-tests/models/test_flash_llama.py | 5 ++++- integration-tests/models/test_flash_neox.py | 1 + integration-tests/models/test_flash_neox_sharded.py | 1 + integration-tests/models/test_flash_santacoder.py | 4 +++- integration-tests/models/test_flash_starcoder.py | 11 +++++++++-- integration-tests/models/test_mt0_base.py | 2 ++ integration-tests/models/test_t5_sharded.py | 1 + server/tests/models/test_bloom.py | 1 + server/tests/models/test_causal_lm.py | 1 + server/tests/models/test_santacoder.py | 2 ++ server/tests/models/test_seq2seq_lm.py | 1 + 15 files changed, 33 insertions(+), 5 deletions(-) diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 902a7158..36d3eac3 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -332,7 +332,8 @@ def generate_load(): client: AsyncClient, prompt: str, max_new_tokens: int, n: int ) -> List[Response]: futures = [ - client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n) + client.generate(prompt, max_new_tokens=max_new_tokens, prefill_details=True) + for _ in range(n) ] return await asyncio.gather(*futures) diff --git a/integration-tests/models/test_bloom_560m.py b/integration-tests/models/test_bloom_560m.py index 809250cb..d7fa6cdb 100644 --- a/integration-tests/models/test_bloom_560m.py +++ b/integration-tests/models/test_bloom_560m.py @@ -19,6 +19,7 @@ async def test_bloom_560m(bloom_560, response_snapshot): "Pour déguster un ortolan, il faut tout d'abord", max_new_tokens=10, top_p=0.9, + prefill_details=True, seed=0, ) @@ -40,6 +41,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot): truncate=5, typical_p=0.9, watermark=True, + prefill_details=True, seed=0, ) diff --git a/integration-tests/models/test_bloom_560m_sharded.py b/integration-tests/models/test_bloom_560m_sharded.py index ee67250a..0781b8d9 100644 --- a/integration-tests/models/test_bloom_560m_sharded.py +++ b/integration-tests/models/test_bloom_560m_sharded.py @@ -19,6 +19,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot): "Pour déguster un ortolan, il faut tout d'abord", max_new_tokens=10, top_p=0.9, + prefill_details=True, seed=0, ) diff --git a/integration-tests/models/test_flash_falcon.py b/integration-tests/models/test_flash_falcon.py index e36a6a28..ebb67cba 100644 --- a/integration-tests/models/test_flash_falcon.py +++ b/integration-tests/models/test_flash_falcon.py @@ -19,6 +19,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot): response = await flash_falcon.generate( "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:", max_new_tokens=10, + prefill_details=True, ) assert response.details.generated_tokens == 10 @@ -40,6 +41,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot): truncate=5, typical_p=0.9, watermark=True, + prefill_details=True, seed=0, ) diff --git a/integration-tests/models/test_flash_llama.py b/integration-tests/models/test_flash_llama.py index edc847c1..fc16d172 100644 --- a/integration-tests/models/test_flash_llama.py +++ b/integration-tests/models/test_flash_llama.py @@ -16,7 +16,9 @@ async def flash_llama(flash_llama_handle): @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama(flash_llama, response_snapshot): - response = await flash_llama.generate("Test request", max_new_tokens=10) + response = await flash_llama.generate( + "Test request", max_new_tokens=10, prefill_details=True + ) assert response.details.generated_tokens == 10 assert response == response_snapshot @@ -37,6 +39,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot): truncate=5, typical_p=0.9, watermark=True, + prefill_details=True, seed=0, ) diff --git a/integration-tests/models/test_flash_neox.py b/integration-tests/models/test_flash_neox.py index daff7f0a..2c165c3a 100644 --- a/integration-tests/models/test_flash_neox.py +++ b/integration-tests/models/test_flash_neox.py @@ -18,6 +18,7 @@ async def test_flash_neox(flash_neox, response_snapshot): response = await flash_neox.generate( "<|USER|>What's your mood today?<|ASSISTANT|>", max_new_tokens=10, + prefill_details=True, ) assert response.details.generated_tokens == 10 diff --git a/integration-tests/models/test_flash_neox_sharded.py b/integration-tests/models/test_flash_neox_sharded.py index a1aa0f07..b5106d11 100644 --- a/integration-tests/models/test_flash_neox_sharded.py +++ b/integration-tests/models/test_flash_neox_sharded.py @@ -18,6 +18,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot): response = await flash_neox_sharded.generate( "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", max_new_tokens=10, + prefill_details=True, ) assert response.details.generated_tokens == 10 diff --git a/integration-tests/models/test_flash_santacoder.py b/integration-tests/models/test_flash_santacoder.py index a15a6439..90146969 100644 --- a/integration-tests/models/test_flash_santacoder.py +++ b/integration-tests/models/test_flash_santacoder.py @@ -15,7 +15,9 @@ async def flash_santacoder(flash_santacoder_handle): @pytest.mark.asyncio async def test_flash_santacoder(flash_santacoder, response_snapshot): - response = await flash_santacoder.generate("def print_hello", max_new_tokens=10) + response = await flash_santacoder.generate( + "def print_hello", max_new_tokens=10, prefill_details=True + ) assert response.details.generated_tokens == 10 assert response == response_snapshot diff --git a/integration-tests/models/test_flash_starcoder.py b/integration-tests/models/test_flash_starcoder.py index 72b298c9..c5910fcc 100644 --- a/integration-tests/models/test_flash_starcoder.py +++ b/integration-tests/models/test_flash_starcoder.py @@ -16,7 +16,9 @@ async def flash_starcoder(flash_starcoder_handle): @pytest.mark.asyncio @pytest.mark.private async def test_flash_starcoder(flash_starcoder, response_snapshot): - response = await flash_starcoder.generate("def print_hello", max_new_tokens=10) + response = await flash_starcoder.generate( + "def print_hello", max_new_tokens=10, prefill_details=True + ) assert response.details.generated_tokens == 10 assert response == response_snapshot @@ -26,7 +28,12 @@ async def test_flash_starcoder(flash_starcoder, response_snapshot): @pytest.mark.private async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot): response = await flash_starcoder.generate( - "def print_hello", max_new_tokens=60, temperature=0.2, top_p=0.95, seed=0 + "def print_hello", + max_new_tokens=60, + temperature=0.2, + top_p=0.95, + prefill_details=True, + seed=0, ) assert response.details.generated_tokens == 60 diff --git a/integration-tests/models/test_mt0_base.py b/integration-tests/models/test_mt0_base.py index 4ed95aad..29de2909 100644 --- a/integration-tests/models/test_mt0_base.py +++ b/integration-tests/models/test_mt0_base.py @@ -19,6 +19,7 @@ async def test_mt0_base(mt0_base, response_snapshot): "Why is the sky blue?", max_new_tokens=10, top_p=0.9, + prefill_details=True, seed=0, ) @@ -40,6 +41,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot): truncate=5, typical_p=0.9, watermark=True, + prefill_details=True, seed=0, ) diff --git a/integration-tests/models/test_t5_sharded.py b/integration-tests/models/test_t5_sharded.py index a2d84330..ff9b8eb7 100644 --- a/integration-tests/models/test_t5_sharded.py +++ b/integration-tests/models/test_t5_sharded.py @@ -18,6 +18,7 @@ async def test_t5_sharded(t5_sharded, response_snapshot): response = await t5_sharded.generate( "Please answer the following question. What is the boiling point of Nitrogen?", max_new_tokens=10, + prefill_details=True, ) assert response == response_snapshot diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 590ba557..338fe053 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -24,6 +24,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", + prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 3f28f5b3..0f9dab2c 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", + prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py index bef8db38..fceec560 100644 --- a/server/tests/models/test_santacoder.py +++ b/server/tests/models/test_santacoder.py @@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="def", + prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, @@ -31,6 +32,7 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="defworld", + prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index a3199d02..299340f8 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", + prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters,