From afb39404e13ed867306fdc39c933872653055856 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 21 Jul 2023 08:15:25 +0000 Subject: [PATCH] Getting closer to the non gptq test (stop sequence doesn't work). --- .../test_flash_llama_gptq.json | 89 ++--- .../test_flash_llama_gptq_all_params.json | 94 +++-- .../test_flash_llama_gptq_load.json | 356 ++++++++---------- .../models/test_flash_llama_gptq.py | 7 +- 4 files changed, 230 insertions(+), 316 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json index 8d2bcf75..e4ffb83b 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json @@ -10,94 +10,79 @@ "text": "" }, { - "id": 20628, - "logprob": -10.4296875, - "text": "Today" + "id": 4321, + "logprob": -9.59375, + "text": "Test" }, { - "id": 306, - "logprob": -2.4140625, - "text": "I" - }, - { - "id": 626, - "logprob": -1.8818359, - "text": "am" - }, - { - "id": 297, - "logprob": -4.4804688, - "text": "in" - }, - { - "id": 3444, - "logprob": -7.0820312, - "text": "France" + "id": 2009, + "logprob": -9.6640625, + "text": "request" } ], "seed": null, "tokens": [ { - "id": 29892, - "logprob": -1.2949219, + "id": 29918, + "logprob": -2.3867188, "special": false, - "text": "," + "text": "_" }, { - "id": 297, - "logprob": -1.9414062, + "id": 5338, + "logprob": -2.8183594, "special": false, - "text": " in" + "text": "uri" }, { - "id": 278, - "logprob": -0.75390625, + "id": 13, + "logprob": -1.6367188, "special": false, - "text": " the" + "text": "\n" }, { - "id": 7062, - "logprob": -2.9101562, + "id": 3057, + "logprob": -1.0527344, "special": false, - "text": " south" + "text": "Test" }, { - "id": 310, - "logprob": -1.0263672, + "id": 2009, + "logprob": -0.6542969, "special": false, - "text": " of" + "text": " request" }, { - "id": 278, - "logprob": -0.5751953, + "id": 29918, + "logprob": -0.056121826, "special": false, - "text": " the" + "text": "_" }, { - "id": 4234, - "logprob": -0.30273438, + "id": 5338, + "logprob": -0.01600647, "special": false, - "text": " country" + "text": "uri" }, { - "id": 29892, - "logprob": -0.69091797, + "id": 13, + "logprob": -0.87939453, "special": false, - "text": "," + "text": "\n" }, { - "id": 297, - "logprob": -1.1015625, + "id": 3057, + "logprob": -0.7529297, "special": false, - "text": " in" + "text": "Test" }, { - "id": 278, - "logprob": -0.5175781, + "id": 2009, + "logprob": -0.2980957, "special": false, - "text": " the" + "text": " request" } ] }, - "generated_text": ", in the south of the country, in the" + "generated_text": "_uri\nTest request_uri\nTest request" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json index 463ea40f..02713a00 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json @@ -10,89 +10,79 @@ "text": "" }, { - "id": 4272, - "logprob": -12.390625, - "text": "city" + "id": 4321, + "logprob": -9.6015625, + "text": "Test" }, { - "id": 310, - "logprob": -2.5292969, - "text": "of" - }, - { - "id": 3444, - "logprob": -11.25, - "text": "France" - }, - { - "id": 338, - "logprob": -4.953125, - "text": "is" + "id": 2009, + "logprob": -9.6640625, + "text": "request" } ], "seed": 0, "tokens": [ { - "id": 278, - "logprob": -0.1796875, + "id": 29899, + "logprob": -1.1640625, "special": false, - "text": " the" + "text": "-" }, { - "id": 12949, - "logprob": -2.2792969, + "id": 1454, + "logprob": -0.07543945, "special": false, - "text": " seat" + "text": "for" }, { - "id": 310, + "id": 29899, "logprob": 0.0, "special": false, - "text": " of" + "text": "-" }, { - "id": 263, - "logprob": -0.09301758, - "special": false, - "text": " a" - }, - { - "id": 5917, - "logprob": -1.3974609, - "special": false, - "text": " Roman" - }, - { - "id": 11865, + "id": 9342, "logprob": 0.0, "special": false, - "text": " Catholic" + "text": "comment" }, { - "id": 3190, + "id": 29901, "logprob": 0.0, "special": false, - "text": " arch" + "text": ":" }, { - "id": 28693, + "id": 396, + "logprob": -0.2956543, + "special": false, + "text": " #" + }, + { + "id": 29906, + "logprob": -0.52734375, + "special": false, + "text": "2" + }, + { + "id": 29900, + "logprob": -0.6899414, + "special": false, + "text": "0" + }, + { + "id": 29896, "logprob": 0.0, "special": false, - "text": "bishop" + "text": "1" }, { - "id": 29892, - "logprob": 0.0, + "id": 29946, + "logprob": -1.5068359, "special": false, - "text": "," - }, - { - "id": 1058, - "logprob": -0.9433594, - "special": false, - "text": " who" + "text": "4" } ] }, - "generated_text": "The capital city of France isthe seat of a Roman Catholic archbishop, who" + "generated_text": "Test request-for-comment: #2014" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json index 354c0534..88bfa4f9 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json @@ -11,96 +11,81 @@ "text": "" }, { - "id": 20628, - "logprob": -10.4296875, - "text": "Today" + "id": 4321, + "logprob": -9.6015625, + "text": "Test" }, { - "id": 306, - "logprob": -2.4179688, - "text": "I" - }, - { - "id": 626, - "logprob": -1.8876953, - "text": "am" - }, - { - "id": 297, - "logprob": -4.484375, - "text": "in" - }, - { - "id": 3444, - "logprob": -7.0820312, - "text": "France" + "id": 2009, + "logprob": -9.671875, + "text": "request" } ], "seed": null, "tokens": [ { - "id": 29892, - "logprob": -1.2958984, + "id": 29918, + "logprob": -2.3828125, "special": false, - "text": "," + "text": "_" }, { - "id": 297, - "logprob": -1.9423828, + "id": 5338, + "logprob": -2.8105469, "special": false, - "text": " in" + "text": "uri" }, { - "id": 278, - "logprob": -0.7475586, + "id": 13, + "logprob": -1.6396484, "special": false, - "text": " the" + "text": "\n" }, { - "id": 7062, - "logprob": -2.9101562, + "id": 3057, + "logprob": -1.0546875, "special": false, - "text": " south" + "text": "Test" }, { - "id": 310, - "logprob": -1.0380859, + "id": 2009, + "logprob": -0.6513672, "special": false, - "text": " of" + "text": " request" }, { - "id": 278, - "logprob": -0.5761719, + "id": 29918, + "logprob": -0.056365967, "special": false, - "text": " the" + "text": "_" }, { - "id": 4234, - "logprob": -0.30297852, + "id": 5338, + "logprob": -0.016082764, "special": false, - "text": " country" + "text": "uri" }, { - "id": 29892, - "logprob": -0.6933594, + "id": 13, + "logprob": -0.87841797, "special": false, - "text": "," + "text": "\n" }, { - "id": 297, - "logprob": -1.0966797, + "id": 3057, + "logprob": -0.7548828, "special": false, - "text": " in" + "text": "Test" }, { - "id": 278, - "logprob": -0.51708984, + "id": 2009, + "logprob": -0.29711914, "special": false, - "text": " the" + "text": " request" } ] }, - "generated_text": ", in the south of the country, in the" + "generated_text": "_uri\nTest request_uri\nTest request" }, { "details": { @@ -114,96 +99,81 @@ "text": "" }, { - "id": 20628, - "logprob": -10.53125, - "text": "Today" + "id": 4321, + "logprob": -9.6015625, + "text": "Test" }, { - "id": 306, - "logprob": -2.4609375, - "text": "I" - }, - { - "id": 626, - "logprob": -1.8857422, - "text": "am" - }, - { - "id": 297, - "logprob": -4.484375, - "text": "in" - }, - { - "id": 3444, - "logprob": -7.1015625, - "text": "France" + "id": 2009, + "logprob": -9.6640625, + "text": "request" } ], "seed": null, "tokens": [ { - "id": 29892, - "logprob": -1.2910156, + "id": 29918, + "logprob": -2.3828125, "special": false, - "text": "," + "text": "_" }, { - "id": 297, - "logprob": -1.9375, + "id": 5338, + "logprob": -2.828125, "special": false, - "text": " in" + "text": "uri" }, { - "id": 278, - "logprob": -0.7416992, + "id": 13, + "logprob": -1.6386719, "special": false, - "text": " the" + "text": "\n" }, { - "id": 7062, - "logprob": -2.90625, + "id": 3057, + "logprob": -1.0527344, "special": false, - "text": " south" + "text": "Test" }, { - "id": 310, - "logprob": -1.0439453, + "id": 2009, + "logprob": -0.6542969, "special": false, - "text": " of" + "text": " request" }, { - "id": 278, - "logprob": -0.5654297, + "id": 29918, + "logprob": -0.055877686, "special": false, - "text": " the" + "text": "_" }, { - "id": 4234, - "logprob": -0.3125, + "id": 5338, + "logprob": -0.016021729, "special": false, - "text": " country" + "text": "uri" }, { - "id": 29892, - "logprob": -0.69384766, + "id": 13, + "logprob": -0.8769531, "special": false, - "text": "," + "text": "\n" }, { - "id": 297, - "logprob": -1.0976562, + "id": 3057, + "logprob": -0.7583008, "special": false, - "text": " in" + "text": "Test" }, { - "id": 278, - "logprob": -0.51416016, + "id": 2009, + "logprob": -0.29833984, "special": false, - "text": " the" + "text": " request" } ] }, - "generated_text": ", in the south of the country, in the" + "generated_text": "_uri\nTest request_uri\nTest request" }, { "details": { @@ -217,96 +187,81 @@ "text": "" }, { - "id": 20628, - "logprob": -10.53125, - "text": "Today" + "id": 4321, + "logprob": -9.6015625, + "text": "Test" }, { - "id": 306, - "logprob": -2.4609375, - "text": "I" - }, - { - "id": 626, - "logprob": -1.8857422, - "text": "am" - }, - { - "id": 297, - "logprob": -4.484375, - "text": "in" - }, - { - "id": 3444, - "logprob": -7.1015625, - "text": "France" + "id": 2009, + "logprob": -9.671875, + "text": "request" } ], "seed": null, "tokens": [ { - "id": 29892, - "logprob": -1.2910156, + "id": 29918, + "logprob": -2.3847656, "special": false, - "text": "," + "text": "_" }, { - "id": 297, - "logprob": -1.9384766, + "id": 5338, + "logprob": -2.8144531, "special": false, - "text": " in" + "text": "uri" }, { - "id": 278, - "logprob": -0.7426758, + "id": 13, + "logprob": -1.6396484, "special": false, - "text": " the" + "text": "\n" }, { - "id": 7062, - "logprob": -2.9042969, + "id": 3057, + "logprob": -1.0527344, "special": false, - "text": " south" + "text": "Test" }, { - "id": 310, - "logprob": -1.0439453, + "id": 2009, + "logprob": -0.65478516, "special": false, - "text": " of" + "text": " request" }, { - "id": 278, - "logprob": -0.56103516, + "id": 29918, + "logprob": -0.056243896, "special": false, - "text": " the" + "text": "_" }, { - "id": 4234, - "logprob": -0.31323242, + "id": 5338, + "logprob": -0.016143799, "special": false, - "text": " country" + "text": "uri" }, { - "id": 29892, - "logprob": -0.6982422, + "id": 13, + "logprob": -0.8808594, "special": false, - "text": "," + "text": "\n" }, { - "id": 297, - "logprob": -1.0976562, + "id": 3057, + "logprob": -0.75341797, "special": false, - "text": " in" + "text": "Test" }, { - "id": 278, - "logprob": -0.52001953, + "id": 2009, + "logprob": -0.2956543, "special": false, - "text": " the" + "text": " request" } ] }, - "generated_text": ", in the south of the country, in the" + "generated_text": "_uri\nTest request_uri\nTest request" }, { "details": { @@ -320,95 +275,80 @@ "text": "" }, { - "id": 20628, - "logprob": -10.53125, - "text": "Today" + "id": 4321, + "logprob": -9.6015625, + "text": "Test" }, { - "id": 306, - "logprob": -2.4609375, - "text": "I" - }, - { - "id": 626, - "logprob": -1.8857422, - "text": "am" - }, - { - "id": 297, - "logprob": -4.484375, - "text": "in" - }, - { - "id": 3444, - "logprob": -7.1015625, - "text": "France" + "id": 2009, + "logprob": -9.6640625, + "text": "request" } ], "seed": null, "tokens": [ { - "id": 29892, - "logprob": -1.2910156, + "id": 29918, + "logprob": -2.3769531, "special": false, - "text": "," + "text": "_" }, { - "id": 297, - "logprob": -1.9394531, + "id": 5338, + "logprob": -2.8183594, "special": false, - "text": " in" + "text": "uri" }, { - "id": 278, - "logprob": -0.74121094, + "id": 13, + "logprob": -1.6396484, "special": false, - "text": " the" + "text": "\n" }, { - "id": 7062, - "logprob": -2.90625, + "id": 3057, + "logprob": -1.0546875, "special": false, - "text": " south" + "text": "Test" }, { - "id": 310, - "logprob": -1.0439453, + "id": 2009, + "logprob": -0.65478516, "special": false, - "text": " of" + "text": " request" }, { - "id": 278, - "logprob": -0.56591797, + "id": 29918, + "logprob": -0.05557251, "special": false, - "text": " the" + "text": "_" }, { - "id": 4234, - "logprob": -0.31713867, + "id": 5338, + "logprob": -0.01612854, "special": false, - "text": " country" + "text": "uri" }, { - "id": 29892, - "logprob": -0.69140625, + "id": 13, + "logprob": -0.8730469, "special": false, - "text": "," + "text": "\n" }, { - "id": 297, - "logprob": -1.0957031, + "id": 3057, + "logprob": -0.7519531, "special": false, - "text": " in" + "text": "Test" }, { - "id": 278, - "logprob": -0.52001953, + "id": 2009, + "logprob": -0.29785156, "special": false, - "text": " the" + "text": " request" } ] }, - "generated_text": ", in the south of the country, in the" + "generated_text": "_uri\nTest request_uri\nTest request" } ] diff --git a/integration-tests/models/test_flash_llama_gptq.py b/integration-tests/models/test_flash_llama_gptq.py index 91f71a7f..bc525f6d 100644 --- a/integration-tests/models/test_flash_llama_gptq.py +++ b/integration-tests/models/test_flash_llama_gptq.py @@ -1,4 +1,3 @@ - import pytest @@ -18,7 +17,7 @@ async def flash_llama_gptq(flash_llama_gptq_handle): @pytest.mark.private async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot): response = await flash_llama_gptq.generate( - "Today I am in France", max_new_tokens=10, decoder_input_details=True + "Test request", max_new_tokens=10, decoder_input_details=True ) assert response.details.generated_tokens == 10 @@ -29,7 +28,7 @@ async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot): @pytest.mark.private async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot): response = await flash_llama_gptq.generate( - "The capital city of France is", + "Test request", max_new_tokens=10, repetition_penalty=1.2, return_full_text=True, @@ -50,7 +49,7 @@ async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot): @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_gptq_load(flash_llama_gptq, generate_load, response_snapshot): - responses = await generate_load(flash_llama_gptq, "Today I am in France", max_new_tokens=10, n=4) + responses = await generate_load(flash_llama_gptq, "Test request", max_new_tokens=10, n=4) assert len(responses) == 4 assert all([r.generated_text == responses[0].generated_text for r in responses])