diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json new file mode 100644 index 00000000..36a2ff4d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json @@ -0,0 +1,94 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40844727, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27905273, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6118164, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68652344, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4619141, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.7993164, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.63134766, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23278809, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2294922, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json new file mode 100644 index 00000000..38117272 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json @@ -0,0 +1,394 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 60, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": 0, + "tokens": [ + { + "id": 2284, + "logprob": -0.296875, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": 0.0, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.28125, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -0.79248047, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.61816406, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.0619812, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -0.4091797, + "special": false, + "text": "def" + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 7670, + "logprob": 0.0, + "special": false, + "text": "hello" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 444, + "logprob": -0.21655273, + "special": false, + "text": "name" + }, + { + "id": 45, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 444, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 731, + "logprob": 0.0, + "special": false, + "text": "):" + }, + { + "id": 303, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": 0.0, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": 0.0, + "special": false, + "text": "Hello" + }, + { + "id": 332, + "logprob": -0.034698486, + "special": false, + "text": " \"" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 655, + "logprob": 0.0, + "special": false, + "text": " name" + }, + { + "id": 494, + "logprob": -0.20141602, + "special": false, + "text": " +" + }, + { + "id": 332, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 16013, + "logprob": 0.0, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": 0.0, + "special": false, + "text": "def" + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 7670, + "logprob": 0.0, + "special": false, + "text": "hello" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 444, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 400, + "logprob": 0.0, + "special": false, + "text": "age" + }, + { + "id": 45, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 444, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 49, + "logprob": 0.0, + "special": false, + "text": "," + }, + { + "id": 11505, + "logprob": 0.0, + "special": false, + "text": " age" + }, + { + "id": 731, + "logprob": 0.0, + "special": false, + "text": "):" + }, + { + "id": 303, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": 0.0, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": 0.0, + "special": false, + "text": "Hello" + }, + { + "id": 332, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 655, + "logprob": 0.0, + "special": false, + "text": " name" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 3021, + "logprob": -0.5761719, + "special": false, + "text": " \"," + }, + { + "id": 863, + "logprob": 0.0, + "special": false, + "text": " you" + }, + { + "id": 904, + "logprob": 0.0, + "special": false, + "text": " are" + }, + { + "id": 332, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 615, + "logprob": 0.0, + "special": false, + "text": " str" + }, + { + "id": 45, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 400, + "logprob": 0.0, + "special": false, + "text": "age" + }, + { + "id": 46, + "logprob": 0.0, + "special": false, + "text": ")" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name + \"!\")\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \", you are \" + str(age)" +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json new file mode 100644 index 00000000..9e82d4be --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json @@ -0,0 +1,378 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40722656, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27954102, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6142578, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68310547, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4570312, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.80126953, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.6303711, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23327637, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2304688, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40722656, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27954102, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6142578, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68310547, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4570312, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.80126953, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.6303711, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23327637, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2304688, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40722656, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27954102, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6142578, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68310547, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4570312, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.80126953, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.6303711, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23327637, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2304688, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40722656, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27954102, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6142578, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68310547, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4570312, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.80126953, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.6303711, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23327637, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2304688, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" + } +] diff --git a/integration-tests/models/test_flash_starcoder2.py b/integration-tests/models/test_flash_starcoder2.py new file mode 100644 index 00000000..ea665b6c --- /dev/null +++ b/integration-tests/models/test_flash_starcoder2.py @@ -0,0 +1,55 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_starcoder2_handle(launcher): + with launcher("bigcode/starcoder2-3b", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_starcoder2(flash_starcoder2_handle): + await flash_starcoder2_handle.health(300) + return flash_starcoder2_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder2(flash_starcoder2, response_snapshot): + response = await flash_starcoder2.generate( + "def print_hello", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot): + response = await flash_starcoder2.generate( + "def print_hello", + max_new_tokens=60, + temperature=0.2, + top_p=0.95, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 60 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder2_load( + flash_starcoder2, generate_load, response_snapshot +): + responses = await generate_load( + flash_starcoder2, "def print_hello", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot