fix: marlin repeat scale for fp8 and bump snapshots

This commit is contained in:
drbh 2024-08-09 16:39:16 +00:00
parent df9eb38733
commit 3f12750a18
5 changed files with 188 additions and 219 deletions

View File

@ -11,79 +11,79 @@
}, },
{ {
"id": 2323, "id": 2323,
"logprob": -9.421875, "logprob": -9.6015625,
"text": "Test" "text": "Test"
}, },
{ {
"id": 1715, "id": 1715,
"logprob": -10.546875, "logprob": -10.515625,
"text": " request" "text": " request"
} }
], ],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 369, "id": 25,
"logprob": -2.1816406, "logprob": -2.1914062,
"special": false, "special": false,
"text": " for" "text": ":"
},
{
"id": 279,
"logprob": -2.6992188,
"special": false,
"text": " the"
}, },
{ {
"id": 220, "id": 220,
"logprob": -3.6308594, "logprob": -3.7324219,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 679, "id": 16,
"logprob": -1.7900391, "logprob": -2.2753906,
"special": false, "special": false,
"text": "201" "text": "1"
}, },
{ {
"id": 24, "id": 13,
"logprob": -1.3554688, "logprob": -1.2070312,
"special": false, "special": false,
"text": "9" "text": "."
}, },
{ {
"id": 12, "id": 20,
"logprob": -2.0039062, "logprob": -2.765625,
"special": false, "special": false,
"text": "-" "text": "5"
}, },
{ {
"id": 2366, "id": 13,
"logprob": -0.4489746, "logprob": -1.1884766,
"special": false, "special": false,
"text": "202" "text": "."
}, },
{ {
"id": 15, "id": 15,
"logprob": -0.037109375, "logprob": -1.5126953,
"special": false, "special": false,
"text": "0" "text": "0"
}, },
{ {
"id": 2978, "id": 12,
"logprob": -0.8100586, "logprob": -2.078125,
"special": false, "special": false,
"text": " school" "text": "-"
}, },
{ {
"id": 1060, "id": 1310,
"logprob": -0.013015747, "logprob": -0.7158203,
"special": false, "special": false,
"text": " year" "text": "rc"
},
{
"id": 16,
"logprob": -1.0234375,
"special": false,
"text": "1"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": " for the 2019-2020 school year" "generated_text": ": 1.5.0-rc1"
} }

View File

@ -1,8 +1,8 @@
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,
"finish_reason": "length", "finish_reason": "stop_sequence",
"generated_tokens": 10, "generated_tokens": 5,
"prefill": [ "prefill": [
{ {
"id": 128000, "id": 128000,
@ -11,12 +11,12 @@
}, },
{ {
"id": 2323, "id": 2323,
"logprob": -9.5625, "logprob": -9.6015625,
"text": "Test" "text": "Test"
}, },
{ {
"id": 1715, "id": 1715,
"logprob": -10.375, "logprob": -10.515625,
"text": " request" "text": " request"
} }
], ],
@ -24,66 +24,36 @@
"tokens": [ "tokens": [
{ {
"id": 25, "id": 25,
"logprob": -0.8984375, "logprob": -0.81103516,
"special": false, "special": false,
"text": ":" "text": ":"
}, },
{ {
"id": 2209, "id": 923,
"logprob": -2.78125,
"special": false,
"text": " Is"
},
{
"id": 279,
"logprob": -0.6328125,
"special": false,
"text": " the"
},
{
"id": 734,
"logprob": -2.703125, "logprob": -2.703125,
"special": false, "special": false,
"text": " function" "text": " add"
},
{
"id": 264,
"logprob": 0.0,
"special": false,
"text": " a"
}, },
{ {
"id": 330, "id": 330,
"logprob": -0.34179688, "logprob": -0.1862793,
"special": false, "special": false,
"text": " \"" "text": " \""
}, },
{ {
"id": 4110, "id": 1985,
"logprob": -2.359375, "logprob": 0.0,
"special": false, "special": false,
"text": "Create" "text": "test"
},
{
"id": 7575,
"logprob": -2.1875,
"special": false,
"text": "Process"
},
{
"id": 1,
"logprob": -0.07910156,
"special": false,
"text": "\""
},
{
"id": 304,
"logprob": -0.83203125,
"special": false,
"text": " in"
},
{
"id": 12468,
"logprob": -1.8203125,
"special": false,
"text": " Win"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "Test request: Is the function \"CreateProcess\" in Win" "generated_text": "Test request: add a \"test"
} }

View File

@ -12,81 +12,81 @@
}, },
{ {
"id": 2323, "id": 2323,
"logprob": -9.421875, "logprob": -9.6015625,
"text": "Test" "text": "Test"
}, },
{ {
"id": 1715, "id": 1715,
"logprob": -10.546875, "logprob": -10.515625,
"text": " request" "text": " request"
} }
], ],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 369, "id": 25,
"logprob": -2.1816406, "logprob": -2.1914062,
"special": false, "special": false,
"text": " for" "text": ":"
},
{
"id": 279,
"logprob": -2.6992188,
"special": false,
"text": " the"
}, },
{ {
"id": 220, "id": 220,
"logprob": -3.6308594, "logprob": -3.7421875,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 679, "id": 16,
"logprob": -1.7988281, "logprob": -2.2753906,
"special": false, "special": false,
"text": "201" "text": "1"
}, },
{ {
"id": 24, "id": 13,
"logprob": -1.3535156, "logprob": -1.2041016,
"special": false, "special": false,
"text": "9" "text": "."
}, },
{ {
"id": 12, "id": 20,
"logprob": -2.0058594, "logprob": -2.7675781,
"special": false, "special": false,
"text": "-" "text": "5"
}, },
{ {
"id": 2366, "id": 13,
"logprob": -0.45410156, "logprob": -1.1884766,
"special": false, "special": false,
"text": "202" "text": "."
}, },
{ {
"id": 15, "id": 15,
"logprob": -0.037109375, "logprob": -1.5244141,
"special": false, "special": false,
"text": "0" "text": "0"
}, },
{ {
"id": 2978, "id": 12,
"logprob": -0.8095703, "logprob": -2.0761719,
"special": false, "special": false,
"text": " school" "text": "-"
}, },
{ {
"id": 1060, "id": 1310,
"logprob": -0.013053894, "logprob": -0.71484375,
"special": false, "special": false,
"text": " year" "text": "rc"
},
{
"id": 16,
"logprob": -1.0244141,
"special": false,
"text": "1"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": " for the 2019-2020 school year" "generated_text": ": 1.5.0-rc1"
}, },
{ {
"details": { "details": {
@ -101,81 +101,81 @@
}, },
{ {
"id": 2323, "id": 2323,
"logprob": -9.421875, "logprob": -9.6015625,
"text": "Test" "text": "Test"
}, },
{ {
"id": 1715, "id": 1715,
"logprob": -10.546875, "logprob": -10.515625,
"text": " request" "text": " request"
} }
], ],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 369, "id": 25,
"logprob": -2.1816406, "logprob": -2.1914062,
"special": false, "special": false,
"text": " for" "text": ":"
},
{
"id": 279,
"logprob": -2.6992188,
"special": false,
"text": " the"
}, },
{ {
"id": 220, "id": 220,
"logprob": -3.6308594, "logprob": -3.7421875,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 679, "id": 16,
"logprob": -1.7988281, "logprob": -2.2753906,
"special": false, "special": false,
"text": "201" "text": "1"
}, },
{ {
"id": 24, "id": 13,
"logprob": -1.3535156, "logprob": -1.2041016,
"special": false, "special": false,
"text": "9" "text": "."
}, },
{ {
"id": 12, "id": 20,
"logprob": -2.0058594, "logprob": -2.7675781,
"special": false, "special": false,
"text": "-" "text": "5"
}, },
{ {
"id": 2366, "id": 13,
"logprob": -0.45410156, "logprob": -1.1884766,
"special": false, "special": false,
"text": "202" "text": "."
}, },
{ {
"id": 15, "id": 15,
"logprob": -0.037109375, "logprob": -1.5244141,
"special": false, "special": false,
"text": "0" "text": "0"
}, },
{ {
"id": 2978, "id": 12,
"logprob": -0.8095703, "logprob": -2.0761719,
"special": false, "special": false,
"text": " school" "text": "-"
}, },
{ {
"id": 1060, "id": 1310,
"logprob": -0.013053894, "logprob": -0.71484375,
"special": false, "special": false,
"text": " year" "text": "rc"
},
{
"id": 16,
"logprob": -1.0244141,
"special": false,
"text": "1"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": " for the 2019-2020 school year" "generated_text": ": 1.5.0-rc1"
}, },
{ {
"details": { "details": {
@ -190,81 +190,81 @@
}, },
{ {
"id": 2323, "id": 2323,
"logprob": -9.421875, "logprob": -9.6015625,
"text": "Test" "text": "Test"
}, },
{ {
"id": 1715, "id": 1715,
"logprob": -10.546875, "logprob": -10.515625,
"text": " request" "text": " request"
} }
], ],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 369, "id": 25,
"logprob": -2.1816406, "logprob": -2.1914062,
"special": false, "special": false,
"text": " for" "text": ":"
},
{
"id": 279,
"logprob": -2.6992188,
"special": false,
"text": " the"
}, },
{ {
"id": 220, "id": 220,
"logprob": -3.6308594, "logprob": -3.7421875,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 679, "id": 16,
"logprob": -1.7988281, "logprob": -2.2753906,
"special": false, "special": false,
"text": "201" "text": "1"
}, },
{ {
"id": 24, "id": 13,
"logprob": -1.3535156, "logprob": -1.2041016,
"special": false, "special": false,
"text": "9" "text": "."
}, },
{ {
"id": 12, "id": 20,
"logprob": -2.0058594, "logprob": -2.7675781,
"special": false, "special": false,
"text": "-" "text": "5"
}, },
{ {
"id": 2366, "id": 13,
"logprob": -0.45410156, "logprob": -1.1884766,
"special": false, "special": false,
"text": "202" "text": "."
}, },
{ {
"id": 15, "id": 15,
"logprob": -0.037109375, "logprob": -1.5244141,
"special": false, "special": false,
"text": "0" "text": "0"
}, },
{ {
"id": 2978, "id": 12,
"logprob": -0.8095703, "logprob": -2.0761719,
"special": false, "special": false,
"text": " school" "text": "-"
}, },
{ {
"id": 1060, "id": 1310,
"logprob": -0.013053894, "logprob": -0.71484375,
"special": false, "special": false,
"text": " year" "text": "rc"
},
{
"id": 16,
"logprob": -1.0244141,
"special": false,
"text": "1"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": " for the 2019-2020 school year" "generated_text": ": 1.5.0-rc1"
}, },
{ {
"details": { "details": {
@ -279,80 +279,80 @@
}, },
{ {
"id": 2323, "id": 2323,
"logprob": -9.421875, "logprob": -9.6015625,
"text": "Test" "text": "Test"
}, },
{ {
"id": 1715, "id": 1715,
"logprob": -10.546875, "logprob": -10.515625,
"text": " request" "text": " request"
} }
], ],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 369, "id": 25,
"logprob": -2.1816406, "logprob": -2.1914062,
"special": false, "special": false,
"text": " for" "text": ":"
},
{
"id": 279,
"logprob": -2.6992188,
"special": false,
"text": " the"
}, },
{ {
"id": 220, "id": 220,
"logprob": -3.6308594, "logprob": -3.7421875,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 679, "id": 16,
"logprob": -1.7988281, "logprob": -2.2753906,
"special": false, "special": false,
"text": "201" "text": "1"
}, },
{ {
"id": 24, "id": 13,
"logprob": -1.3535156, "logprob": -1.2041016,
"special": false, "special": false,
"text": "9" "text": "."
}, },
{ {
"id": 12, "id": 20,
"logprob": -2.0058594, "logprob": -2.7675781,
"special": false, "special": false,
"text": "-" "text": "5"
}, },
{ {
"id": 2366, "id": 13,
"logprob": -0.45410156, "logprob": -1.1884766,
"special": false, "special": false,
"text": "202" "text": "."
}, },
{ {
"id": 15, "id": 15,
"logprob": -0.037109375, "logprob": -1.5244141,
"special": false, "special": false,
"text": "0" "text": "0"
}, },
{ {
"id": 2978, "id": 12,
"logprob": -0.8095703, "logprob": -2.0761719,
"special": false, "special": false,
"text": " school" "text": "-"
}, },
{ {
"id": 1060, "id": 1310,
"logprob": -0.013053894, "logprob": -0.71484375,
"special": false, "special": false,
"text": " year" "text": "rc"
},
{
"id": 16,
"logprob": -1.0244141,
"special": false,
"text": "1"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": " for the 2019-2020 school year" "generated_text": ": 1.5.0-rc1"
} }
] ]

View File

@ -48,16 +48,15 @@ async def test_flash_llama_fp8_all_params(flash_llama_fp8, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
# TODO: fix and re-enable
# @pytest.mark.release # @pytest.mark.release
# @pytest.mark.asyncio @pytest.mark.asyncio
# @pytest.mark.private @pytest.mark.private
# async def test_flash_llama_fp8_load(flash_llama_fp8, generate_load, response_snapshot): async def test_flash_llama_fp8_load(flash_llama_fp8, generate_load, response_snapshot):
# responses = await generate_load( responses = await generate_load(
# flash_llama_fp8, "Test request", max_new_tokens=10, n=4 flash_llama_fp8, "Test request", max_new_tokens=10, n=4
# ) )
# assert len(responses) == 4 assert len(responses) == 4
# assert all([r.generated_text == responses[0].generated_text for r in responses]) assert all([r.generated_text == responses[0].generated_text for r in responses])
# assert responses == response_snapshot assert responses == response_snapshot

View File

@ -39,7 +39,7 @@ class GPTQMarlinFP8Linear(nn.Module):
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
scales = scales.unsqueeze(0) scales = scales.unsqueeze(0)
if scales.shape[1] == 1: if scales.size(0) == 1:
out_features, in_features = qweight.shape out_features, in_features = qweight.shape
scales = scales.repeat(1, out_features) scales = scales.repeat(1, out_features)
qweight, scales = repack_fp8_for_marlin(qweight, scales) qweight, scales = repack_fp8_for_marlin(qweight, scales)