From 8645fd39e1474aded0b62d74de9bf406ea0257ae Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 12 Jul 2023 16:42:34 +0000 Subject: [PATCH] tests --- integration-tests/conftest.py | 8 +- .../test_flash_starcoder_gptq.json | 93 +++++ ...t_flash_starcoder_gptq_default_params.json | 393 ++++++++++++++++++ .../test_flash_starcoder_gptq_load.json | 374 +++++++++++++++++ .../models/test_flash_starcoder_gptq.py | 58 +++ 5 files changed, 922 insertions(+), 4 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json create mode 100644 integration-tests/models/test_flash_starcoder_gptq.py diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 8f59d75a..812b1d18 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -232,9 +232,9 @@ def launcher(event_loop): if num_shard is not None: args.extend(["--num-shard", str(num_shard)]) - if quantize: + if quantize is not None: args.append("--quantize") - args.append("bitsandbytes") + args.append(quantize) if trust_remote_code: args.append("--trust-remote-code") @@ -275,9 +275,9 @@ def launcher(event_loop): if num_shard is not None: args.extend(["--num-shard", str(num_shard)]) - if quantize: + if quantize is not None: args.append("--quantize") - args.append("bitsandbytes") + args.append(quantize) if trust_remote_code: args.append("--trust-remote-code") diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json new file mode 100644 index 00000000..8505c1db --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json @@ -0,0 +1,93 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 1459, + "logprob": -5.6289062, + "text": " print" + }, + { + "id": 81, + "logprob": -1.6005859, + "text": "_" + }, + { + "id": 7656, + "logprob": -5.9921875, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2262, + "logprob": -0.7705078, + "special": false, + "text": "():" + }, + { + "id": 284, + "logprob": -0.2590332, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": -0.39379883, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": -0.61376953, + "special": false, + "text": "(\"" + }, + { + "id": 8279, + "logprob": -0.47338867, + "special": false, + "text": "Hello" + }, + { + "id": 10896, + "logprob": -1.5068359, + "special": false, + "text": " World" + }, + { + "id": 657, + "logprob": -0.80810547, + "special": false, + "text": "\")" + }, + { + "id": 203, + "logprob": -0.7397461, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": -0.35229492, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": -1.0371094, + "special": false, + "text": "def" + } + ] + }, + "generated_text": "():\n print(\"Hello World\")\n\ndef" +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json new file mode 100644 index 00000000..89e02c07 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json @@ -0,0 +1,393 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 60, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 1459, + "logprob": -5.6328125, + "text": " print" + }, + { + "id": 81, + "logprob": -1.6035156, + "text": "_" + }, + { + "id": 7656, + "logprob": -5.9882812, + "text": "hello" + } + ], + "seed": 0, + "tokens": [ + { + "id": 2262, + "logprob": -0.042999268, + "special": false, + "text": "():" + }, + { + "id": 284, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": 0.0, + "special": false, + "text": "(\"" + }, + { + "id": 8279, + "logprob": 0.0, + "special": false, + "text": "Hello" + }, + { + "id": 10896, + "logprob": -0.38549805, + "special": false, + "text": " World" + }, + { + "id": 657, + "logprob": -0.5229492, + "special": false, + "text": "\")" + }, + { + "id": 203, + "logprob": -0.10632324, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": -0.20141602, + "special": false, + "text": "def" + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 7656, + "logprob": 0.0, + "special": false, + "text": "hello" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 426, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 26, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 426, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 711, + "logprob": 0.0, + "special": false, + "text": "):" + }, + { + "id": 284, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": -0.16027832, + "special": false, + "text": "(\"" + }, + { + "id": 8279, + "logprob": 0.0, + "special": false, + "text": "Hello" + }, + { + "id": 313, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 474, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 636, + "logprob": 0.0, + "special": false, + "text": " name" + }, + { + "id": 27, + "logprob": 0.0, + "special": false, + "text": ")" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": 0.0, + "special": false, + "text": "def" + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 7656, + "logprob": 0.0, + "special": false, + "text": "hello" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 426, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 381, + "logprob": 0.0, + "special": false, + "text": "age" + }, + { + "id": 26, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 426, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 30, + "logprob": 0.0, + "special": false, + "text": "," + }, + { + "id": 11442, + "logprob": 0.0, + "special": false, + "text": " age" + }, + { + "id": 711, + "logprob": 0.0, + "special": false, + "text": "):" + }, + { + "id": 284, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": 0.0, + "special": false, + "text": "(\"" + }, + { + "id": 8279, + "logprob": 0.0, + "special": false, + "text": "Hello" + }, + { + "id": 313, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 474, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 636, + "logprob": 0.0, + "special": false, + "text": " name" + }, + { + "id": 474, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 313, + "logprob": -0.6328125, + "special": false, + "text": " \"" + }, + { + "id": 313, + "logprob": -1.7011719, + "special": false, + "text": " \"" + }, + { + "id": 474, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 596, + "logprob": 0.0, + "special": false, + "text": " str" + }, + { + "id": 26, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 381, + "logprob": 0.0, + "special": false, + "text": "age" + }, + { + "id": 490, + "logprob": 0.0, + "special": false, + "text": "))" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": 0.0, + "special": false, + "text": "def" + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + } + ] + }, + "generated_text": "():\n print(\"Hello World\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \" \" + str(age))\n\ndef print" +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json new file mode 100644 index 00000000..0b3ad554 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json @@ -0,0 +1,374 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 1459, + "logprob": -5.6289062, + "text": " print" + }, + { + "id": 81, + "logprob": -1.6005859, + "text": "_" + }, + { + "id": 7656, + "logprob": -5.9921875, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2262, + "logprob": -0.7705078, + "special": false, + "text": "():" + }, + { + "id": 284, + "logprob": -0.2602539, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": -0.39282227, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": -0.6113281, + "special": false, + "text": "(\"" + }, + { + "id": 8279, + "logprob": -0.4765625, + "special": false, + "text": "Hello" + }, + { + "id": 10896, + "logprob": -1.5068359, + "special": false, + "text": " World" + }, + { + "id": 657, + "logprob": -0.8154297, + "special": false, + "text": "\")" + }, + { + "id": 203, + "logprob": -0.7319336, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": -0.35229492, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": -1.0380859, + "special": false, + "text": "def" + } + ] + }, + "generated_text": "():\n print(\"Hello World\")\n\ndef" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 1459, + "logprob": -5.6289062, + "text": " print" + }, + { + "id": 81, + "logprob": -1.6005859, + "text": "_" + }, + { + "id": 7656, + "logprob": -5.9921875, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2262, + "logprob": -0.7705078, + "special": false, + "text": "():" + }, + { + "id": 284, + "logprob": -0.2602539, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": -0.39282227, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": -0.6113281, + "special": false, + "text": "(\"" + }, + { + "id": 8279, + "logprob": -0.4765625, + "special": false, + "text": "Hello" + }, + { + "id": 10896, + "logprob": -1.5068359, + "special": false, + "text": " World" + }, + { + "id": 657, + "logprob": -0.8154297, + "special": false, + "text": "\")" + }, + { + "id": 203, + "logprob": -0.7319336, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": -0.35229492, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": -1.0380859, + "special": false, + "text": "def" + } + ] + }, + "generated_text": "():\n print(\"Hello World\")\n\ndef" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 1459, + "logprob": -5.6289062, + "text": " print" + }, + { + "id": 81, + "logprob": -1.6005859, + "text": "_" + }, + { + "id": 7656, + "logprob": -5.9921875, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2262, + "logprob": -0.7705078, + "special": false, + "text": "():" + }, + { + "id": 284, + "logprob": -0.2602539, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": -0.39282227, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": -0.6113281, + "special": false, + "text": "(\"" + }, + { + "id": 8279, + "logprob": -0.4765625, + "special": false, + "text": "Hello" + }, + { + "id": 10896, + "logprob": -1.5068359, + "special": false, + "text": " World" + }, + { + "id": 657, + "logprob": -0.8154297, + "special": false, + "text": "\")" + }, + { + "id": 203, + "logprob": -0.7319336, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": -0.35229492, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": -1.0380859, + "special": false, + "text": "def" + } + ] + }, + "generated_text": "():\n print(\"Hello World\")\n\ndef" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 1459, + "logprob": -5.6289062, + "text": " print" + }, + { + "id": 81, + "logprob": -1.6005859, + "text": "_" + }, + { + "id": 7656, + "logprob": -5.9921875, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2262, + "logprob": -0.7705078, + "special": false, + "text": "():" + }, + { + "id": 284, + "logprob": -0.2602539, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": -0.39282227, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": -0.6113281, + "special": false, + "text": "(\"" + }, + { + "id": 8279, + "logprob": -0.4765625, + "special": false, + "text": "Hello" + }, + { + "id": 10896, + "logprob": -1.5068359, + "special": false, + "text": " World" + }, + { + "id": 657, + "logprob": -0.8154297, + "special": false, + "text": "\")" + }, + { + "id": 203, + "logprob": -0.7319336, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": -0.35229492, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": -1.0380859, + "special": false, + "text": "def" + } + ] + }, + "generated_text": "():\n print(\"Hello World\")\n\ndef" + } +] diff --git a/integration-tests/models/test_flash_starcoder_gptq.py b/integration-tests/models/test_flash_starcoder_gptq.py new file mode 100644 index 00000000..7c3f199d --- /dev/null +++ b/integration-tests/models/test_flash_starcoder_gptq.py @@ -0,0 +1,58 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_santacoder_gptq_handle(launcher): + with launcher("Narsil/starcoder-gptq", num_shard=2, quantize="gptq") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_santacoder_gptq(flash_santacoder_gptq_handle): + await flash_santacoder_gptq_handle.health(300) + return flash_santacoder_gptq_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_santacoder_gptq(flash_santacoder_gptq, response_snapshot): + response = await flash_santacoder_gptq.generate( + 'def sum(L: List[int]):\n"""Sums all elements from the list L."""', max_new_tokens=40, decoder_input_details=True + ) + + # assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_santacoder_gptq_all_params(flash_santacoder_gptq, response_snapshot): + response = await flash_santacoder_gptq.generate( + 'def sum(L: List[int]):\n"""Sums all elements from the list L."""', + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + #assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_santacoder_gptq_load(flash_santacoder_gptq, generate_load, response_snapshot): + responses = await generate_load(flash_santacoder_gptq, 'def sum(L: List[int]):\n"""Sums all elements from the list L."""', max_new_tokens=10, n=4) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot \ No newline at end of file