diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json index 53055e42..5e537bb7 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json @@ -1,193 +1,194 @@ { - "generated_text": "\n return sum(L) / len(L)\n\n\ndef geometric_mean(L", "details": { "best_of_sequences": null, "finish_reason": "length", "generated_tokens": 20, - "seed": null, "prefill": [ { "id": 589, - "text": "def", - "logprob": null + "logprob": null, + "text": "def" }, { "id": 3226, - "text": " ge", - "logprob": -9.0234375 + "logprob": -8.5859375, + "text": " ge" }, { "id": 21017, - "text": "ometric", - "logprob": -9.0859375 + "logprob": -7.5859375, + "text": "ometric" }, { "id": 81, - "text": "_", - "logprob": -0.25878906 + "logprob": -0.2668457, + "text": "_" }, { "id": 6009, - "text": "mean", - "logprob": -2.2109375 + "logprob": -1.6416016, + "text": "mean" }, { "id": 26, - "text": "(", - "logprob": -0.30371094 + "logprob": -0.22705078, + "text": "(" }, { "id": 62, - "text": "L", - "logprob": -5.6054688 + "logprob": -5.2304688, + "text": "L" }, { "id": 44, - "text": ":", - "logprob": -3.0722656 + "logprob": -3.0976562, + "text": ":" }, { "id": 1682, - "text": " List", - "logprob": -0.6879883 + "logprob": -1.1044922, + "text": " List" }, { "id": 77, - "text": "[", - "logprob": -0.38500977 + "logprob": -0.14294434, + "text": "[" }, { "id": 1808, - "text": "float", - "logprob": -0.984375 + "logprob": -0.32299805, + "text": "float" }, { "id": 10794, - "text": "]):", - "logprob": -2.5351562 + "logprob": -2.8164062, + "text": "]):" } ], + "seed": null, "tokens": [ { "id": 284, - "text": "\n ", - "logprob": -1.1738281, - "special": false + "logprob": -0.1282959, + "special": false, + "text": "\n " }, { - "id": 442, - "text": " return", - "logprob": -0.95947266, - "special": false + "id": 1524, + "logprob": -0.97998047, + "special": false, + "text": " \"\"\"" }, { - "id": 3632, - "text": " sum", - "logprob": -1.4199219, - "special": false + "id": 284, + "logprob": -0.7006836, + "special": false, + "text": "\n " }, { - "id": 26, - "text": "(", - "logprob": -0.085876465, - "special": false + "id": 14883, + "logprob": -2.1933594, + "special": false, + "text": " Calculate" }, { - "id": 62, - "text": "L", - "logprob": -0.09875488, - "special": false - }, - { - "id": 27, - "text": ")", - "logprob": -0.30517578, - "special": false - }, - { - "id": 517, - "text": " /", - "logprob": -0.42089844, - "special": false - }, - { - "id": 2069, - "text": " len", - "logprob": -0.042053223, - "special": false - }, - { - "id": 26, - "text": "(", - "logprob": -0.0011806488, - "special": false - }, - { - "id": 62, - "text": "L", - "logprob": -0.0005259514, - "special": false - }, - { - "id": 27, - "text": ")", - "logprob": -0.0017633438, - "special": false - }, - { - "id": 478, - "text": "\n\n", - "logprob": -0.69189453, - "special": false - }, - { - "id": 203, - "text": "\n", - "logprob": -0.041870117, - "special": false - }, - { - "id": 589, - "text": "def", - "logprob": -0.27856445, - "special": false + "id": 322, + "logprob": -0.2697754, + "special": false, + "text": " the" }, { "id": 3226, - "text": " ge", - "logprob": -1.7255859, - "special": false + "logprob": -0.0836792, + "special": false, + "text": " ge" }, { "id": 21017, - "text": "ometric", - "logprob": -0.011291504, - "special": false + "logprob": -0.018737793, + "special": false, + "text": "ometric" }, { - "id": 81, - "text": "_", - "logprob": -0.008430481, - "special": false + "id": 5651, + "logprob": -0.028640747, + "special": false, + "text": " mean" }, { - "id": 6009, - "text": "mean", - "logprob": -0.025787354, - "special": false + "id": 432, + "logprob": -0.29467773, + "special": false, + "text": " of" }, { - "id": 26, - "text": "(", - "logprob": -0.073913574, - "special": false + "id": 312, + "logprob": -0.31518555, + "special": false, + "text": " a" }, { - "id": 62, - "text": "L", - "logprob": -0.09967041, - "special": false + "id": 1149, + "logprob": -0.20605469, + "special": false, + "text": " list" + }, + { + "id": 432, + "logprob": -0.23254395, + "special": false, + "text": " of" + }, + { + "id": 7515, + "logprob": -0.4489746, + "special": false, + "text": " numbers" + }, + { + "id": 32, + "logprob": -0.6044922, + "special": false, + "text": "." + }, + { + "id": 446, + "logprob": -0.63964844, + "special": false, + "text": "\n\n " + }, + { + "id": 499, + "logprob": -1.1953125, + "special": false, + "text": " :" + }, + { + "id": 753, + "logprob": -0.03515625, + "special": false, + "text": "param" + }, + { + "id": 498, + "logprob": -0.06311035, + "special": false, + "text": " L" + }, + { + "id": 44, + "logprob": -0.003414154, + "special": false, + "text": ":" + }, + { + "id": 1682, + "logprob": -1.3310547, + "special": false, + "text": " List" } - ] - } + ], + "top_tokens": null + }, + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a list of numbers.\n\n :param L: List" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json index 1ace3814..bf0f5146 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json @@ -11,57 +11,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.5898438, "text": "ometric" }, { "id": 81, - "logprob": -0.25830078, + "logprob": -0.26586914, "text": "_" }, { "id": 6009, - "logprob": -2.1875, + "logprob": -1.6347656, "text": "mean" }, { "id": 26, - "logprob": -0.30004883, + "logprob": -0.22705078, "text": "(" }, { "id": 62, - "logprob": -5.6171875, + "logprob": -5.2382812, "text": "L" }, { "id": 44, - "logprob": -3.078125, + "logprob": -3.0996094, "text": ":" }, { "id": 1682, - "logprob": -0.68066406, + "logprob": -1.1025391, "text": " List" }, { "id": 77, - "logprob": -0.38745117, + "logprob": -0.14294434, "text": "[" }, { "id": 1808, - "logprob": -0.9453125, + "logprob": -0.32226562, "text": "float" }, { "id": 10794, - "logprob": -2.5371094, + "logprob": -2.8164062, "text": "]):" } ], @@ -69,19 +69,19 @@ "tokens": [ { "id": 284, - "logprob": -0.051635742, + "logprob": 0.0, "special": false, "text": "\n " }, { "id": 442, - "logprob": 0.0, + "logprob": -1.3134766, "special": false, "text": " return" }, { "id": 11665, - "logprob": -1.2236328, + "logprob": -0.10021973, "special": false, "text": " reduce" }, @@ -129,7 +129,7 @@ }, { "id": 319, - "logprob": 0.0, + "logprob": -0.42871094, "special": false, "text": " *" }, @@ -158,36 +158,37 @@ "text": ")" }, { - "id": 203, - "logprob": -0.12695312, - "special": false, - "text": "\n" - }, - { - "id": 203, + "id": 1115, "logprob": 0.0, "special": false, - "text": "\n" + "text": " **" }, { - "id": 589, + "id": 308, "logprob": 0.0, "special": false, - "text": "def" + "text": " (" }, { - "id": 3226, + "id": 35, "logprob": 0.0, "special": false, - "text": " ge" + "text": "1" }, { - "id": 21017, + "id": 32, + "logprob": -0.31323242, + "special": false, + "text": "." + }, + { + "id": 34, "logprob": 0.0, "special": false, - "text": "ometric" + "text": "0" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return reduce(lambda x, y: x * y, L)\n\ndef geometric" + "generated_text": "\n return reduce(lambda x, y: x * y, L) ** (1.0" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json index 5381ce5a..46a21ed8 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json @@ -12,57 +12,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.5820312, "text": "ometric" }, { "id": 81, - "logprob": -0.25927734, + "logprob": -0.26708984, "text": "_" }, { "id": 6009, - "logprob": -2.25, + "logprob": -1.6386719, "text": "mean" }, { "id": 26, - "logprob": -0.30126953, + "logprob": -0.22717285, "text": "(" }, { "id": 62, - "logprob": -5.7539062, + "logprob": -5.234375, "text": "L" }, { "id": 44, - "logprob": -3.0878906, + "logprob": -3.1015625, "text": ":" }, { "id": 1682, - "logprob": -0.6845703, + "logprob": -1.1083984, "text": " List" }, { "id": 77, - "logprob": -0.3918457, + "logprob": -0.14294434, "text": "[" }, { "id": 1808, - "logprob": -0.8798828, + "logprob": -0.32592773, "text": "float" }, { "id": 10794, - "logprob": -2.4980469, + "logprob": -2.8164062, "text": "]):" } ], @@ -70,67 +70,68 @@ "tokens": [ { "id": 284, - "logprob": -1.1533203, + "logprob": -0.12817383, "special": false, "text": "\n " }, { - "id": 442, - "logprob": -0.91796875, + "id": 1524, + "logprob": -0.9863281, "special": false, - "text": " return" + "text": " \"\"\"" }, { - "id": 3632, - "logprob": -1.3291016, + "id": 284, + "logprob": -0.7011719, "special": false, - "text": " sum" + "text": "\n " }, { - "id": 26, - "logprob": -0.08062744, + "id": 14883, + "logprob": -2.2050781, "special": false, - "text": "(" + "text": " Calculate" }, { - "id": 62, - "logprob": -0.097717285, + "id": 322, + "logprob": -0.2668457, "special": false, - "text": "L" + "text": " the" }, { - "id": 27, - "logprob": -0.29003906, + "id": 3226, + "logprob": -0.08465576, "special": false, - "text": ")" + "text": " ge" }, { - "id": 517, - "logprob": -0.34958984, + "id": 21017, + "logprob": -0.019012451, "special": false, - "text": " /" + "text": "ometric" }, { - "id": 2069, - "logprob": -0.03829956, + "id": 5651, + "logprob": -0.028625488, "special": false, - "text": " len" + "text": " mean" }, { - "id": 26, - "logprob": -0.0011987686, + "id": 432, + "logprob": -0.29418945, "special": false, - "text": "(" + "text": " of" }, { - "id": 62, - "logprob": -0.00050878525, + "id": 312, + "logprob": -0.3161621, "special": false, - "text": "L" + "text": " a" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return sum(L) / len(L" + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" }, { "details": { @@ -145,57 +146,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.59375, "text": "ometric" }, { "id": 81, - "logprob": -0.25878906, + "logprob": -0.26953125, "text": "_" }, { "id": 6009, - "logprob": -2.2109375, + "logprob": -1.640625, "text": "mean" }, { "id": 26, - "logprob": -0.30371094, + "logprob": -0.22705078, "text": "(" }, { "id": 62, - "logprob": -5.6054688, + "logprob": -5.234375, "text": "L" }, { "id": 44, - "logprob": -3.0722656, + "logprob": -3.1132812, "text": ":" }, { "id": 1682, - "logprob": -0.6879883, + "logprob": -1.1123047, "text": " List" }, { "id": 77, - "logprob": -0.38500977, + "logprob": -0.14294434, "text": "[" }, { "id": 1808, - "logprob": -0.984375, + "logprob": -0.32299805, "text": "float" }, { "id": 10794, - "logprob": -2.5351562, + "logprob": -2.8164062, "text": "]):" } ], @@ -203,67 +204,68 @@ "tokens": [ { "id": 284, - "logprob": -1.1738281, + "logprob": -0.12854004, "special": false, "text": "\n " }, { - "id": 442, - "logprob": -0.9584961, + "id": 1524, + "logprob": -0.9897461, "special": false, - "text": " return" + "text": " \"\"\"" }, { - "id": 3632, - "logprob": -1.4169922, + "id": 284, + "logprob": -0.69970703, "special": false, - "text": " sum" + "text": "\n " }, { - "id": 26, - "logprob": -0.085876465, + "id": 14883, + "logprob": -2.2050781, "special": false, - "text": "(" + "text": " Calculate" }, { - "id": 62, - "logprob": -0.0982666, + "id": 322, + "logprob": -0.2668457, "special": false, - "text": "L" + "text": " the" }, { - "id": 27, - "logprob": -0.3022461, + "id": 3226, + "logprob": -0.08496094, "special": false, - "text": ")" + "text": " ge" }, { - "id": 517, - "logprob": -0.40504883, + "id": 21017, + "logprob": -0.019012451, "special": false, - "text": " /" + "text": "ometric" }, { - "id": 2069, - "logprob": -0.041656494, + "id": 5651, + "logprob": -0.029037476, "special": false, - "text": " len" + "text": " mean" }, { - "id": 26, - "logprob": -0.0011844635, + "id": 432, + "logprob": -0.2939453, "special": false, - "text": "(" + "text": " of" }, { - "id": 62, - "logprob": -0.0005264282, + "id": 312, + "logprob": -0.31591797, "special": false, - "text": "L" + "text": " a" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return sum(L) / len(L" + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" }, { "details": { @@ -278,57 +280,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.5859375, "text": "ometric" }, { "id": 81, - "logprob": -0.25927734, + "logprob": -0.26586914, "text": "_" }, { "id": 6009, - "logprob": -2.25, + "logprob": -1.6347656, "text": "mean" }, { "id": 26, - "logprob": -0.30126953, + "logprob": -0.22766113, "text": "(" }, { "id": 62, - "logprob": -5.7539062, + "logprob": -5.2265625, "text": "L" }, { "id": 44, - "logprob": -3.0878906, + "logprob": -3.0976562, "text": ":" }, { "id": 1682, - "logprob": -0.6845703, + "logprob": -1.1025391, "text": " List" }, { "id": 77, - "logprob": -0.3918457, + "logprob": -0.1427002, "text": "[" }, { "id": 1808, - "logprob": -0.8798828, + "logprob": -0.32592773, "text": "float" }, { "id": 10794, - "logprob": -2.4980469, + "logprob": -2.8164062, "text": "]):" } ], @@ -336,67 +338,68 @@ "tokens": [ { "id": 284, - "logprob": -1.1533203, + "logprob": -0.13012695, "special": false, "text": "\n " }, { - "id": 442, - "logprob": -0.9165039, + "id": 1524, + "logprob": -0.98046875, "special": false, - "text": " return" + "text": " \"\"\"" }, { - "id": 3632, - "logprob": -1.328125, + "id": 284, + "logprob": -0.69921875, "special": false, - "text": " sum" + "text": "\n " }, { - "id": 26, - "logprob": -0.07946777, + "id": 14883, + "logprob": -2.1992188, "special": false, - "text": "(" + "text": " Calculate" }, { - "id": 62, - "logprob": -0.09820557, + "id": 322, + "logprob": -0.2668457, "special": false, - "text": "L" + "text": " the" }, { - "id": 27, - "logprob": -0.28930664, + "id": 3226, + "logprob": -0.083496094, "special": false, - "text": ")" + "text": " ge" }, { - "id": 517, - "logprob": -0.34592773, + "id": 21017, + "logprob": -0.01902771, "special": false, - "text": " /" + "text": "ometric" }, { - "id": 2069, - "logprob": -0.038330078, + "id": 5651, + "logprob": -0.029006958, "special": false, - "text": " len" + "text": " mean" }, { - "id": 26, - "logprob": -0.0011940002, + "id": 432, + "logprob": -0.29248047, "special": false, - "text": "(" + "text": " of" }, { - "id": 62, - "logprob": -0.00050878525, + "id": 312, + "logprob": -0.3161621, "special": false, - "text": "L" + "text": " a" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return sum(L) / len(L" + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" }, { "details": { @@ -411,57 +414,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.5859375, "text": "ometric" }, { "id": 81, - "logprob": -0.25927734, + "logprob": -0.26904297, "text": "_" }, { "id": 6009, - "logprob": -2.25, + "logprob": -1.6386719, "text": "mean" }, { "id": 26, - "logprob": -0.30126953, + "logprob": -0.22705078, "text": "(" }, { "id": 62, - "logprob": -5.7539062, + "logprob": -5.234375, "text": "L" }, { "id": 44, - "logprob": -3.0878906, + "logprob": -3.1132812, "text": ":" }, { "id": 1682, - "logprob": -0.6845703, + "logprob": -1.1074219, "text": " List" }, { "id": 77, - "logprob": -0.3918457, + "logprob": -0.14477539, "text": "[" }, { "id": 1808, - "logprob": -0.8798828, + "logprob": -0.3256836, "text": "float" }, { "id": 10794, - "logprob": -2.4980469, + "logprob": -2.8027344, "text": "]):" } ], @@ -469,66 +472,67 @@ "tokens": [ { "id": 284, - "logprob": -1.1533203, + "logprob": -0.12915039, "special": false, "text": "\n " }, { - "id": 442, - "logprob": -0.91259766, + "id": 1524, + "logprob": -0.98535156, "special": false, - "text": " return" + "text": " \"\"\"" }, { - "id": 3632, - "logprob": -1.3251953, + "id": 284, + "logprob": -0.69921875, "special": false, - "text": " sum" + "text": "\n " }, { - "id": 26, - "logprob": -0.08062744, + "id": 14883, + "logprob": -2.2011719, "special": false, - "text": "(" + "text": " Calculate" }, { - "id": 62, - "logprob": -0.09906006, + "id": 322, + "logprob": -0.26708984, "special": false, - "text": "L" + "text": " the" }, { - "id": 27, - "logprob": -0.28979492, + "id": 3226, + "logprob": -0.08502197, "special": false, - "text": ")" + "text": " ge" }, { - "id": 517, - "logprob": -0.35958984, + "id": 21017, + "logprob": -0.019012451, "special": false, - "text": " /" + "text": "ometric" }, { - "id": 2069, - "logprob": -0.038604736, + "id": 5651, + "logprob": -0.028625488, "special": false, - "text": " len" + "text": " mean" }, { - "id": 26, - "logprob": -0.0011901855, + "id": 432, + "logprob": -0.29589844, "special": false, - "text": "(" + "text": " of" }, { - "id": 62, - "logprob": -0.0005078316, + "id": 312, + "logprob": -0.31591797, "special": false, - "text": "L" + "text": " a" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return sum(L) / len(L" + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" } ] diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 22d03adf..81041046 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -69,9 +69,17 @@ def _load_multi_mqa_gptq( qzeros = torch.cat([q_tensor, kv_tensor], dim=1) qzeros = qzeros.to(device=weights.device) - g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") - g_idx = g_idx.to(device=weights.device) - bits, groupsize, _ = weights._get_gptq_params() + bits, groupsize, _, quant_method, = weights._get_gptq_params() + if quant_method == "gptq": + g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") + g_idx = g_idx.to(device=weights.device) + elif quant_method == "awq": + g_idx = None + from text_generation_server.utils.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) from text_generation_server.utils.layers import HAS_EXLLAMA