Fix GPTQ for models which do not have float16 as the default dtype

Before this change, GPTQ models would not work if the model's default
data type is not `float16`. For example, Gemma GPTQ models would fail
because the default dtype of Gemma is `bfloat16`. There are two issues:

- If the default `dtype` is not `float16`, the quantizer's `float16`
  parameters get converted to that dtype, and the GPTQ/AWQ kernels
  cannot handle non-`float16` types.
- The same applies to the inputs of the quantized ops: they arrive in
  the model's default dtype rather than `float16`.

This is resolved by setting the dtype of GPTQ/AWQ-quantized models to
`float16`.
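
A minimal standalone sketch of the dtype resolution this change introduces; the real change lives inline in `get_model` (see the diff at the end of this page), and `resolve_dtype` is a hypothetical helper written only for illustration:

```python
from typing import Optional

import torch


def resolve_dtype(dtype: Optional[str], quantize: Optional[str]) -> Optional[torch.dtype]:
    """Resolve the torch dtype used to load model weights."""
    if dtype is None:
        if quantize in ["awq", "gptq"]:
            # These quantizers only work with float16 params, so do not
            # fall back to the model's own default (e.g. bfloat16 for Gemma).
            return torch.float16
        # Keep it as default for now and let
        # every model resolve their own default dtype.
        return None
    elif dtype == "float16":
        return torch.float16
    elif dtype == "bfloat16":
        return torch.bfloat16
    raise RuntimeError(f"Unknown dtype {dtype}")


# A Gemma GPTQ model launched without an explicit dtype now gets float16.
assert resolve_dtype(None, "gptq") == torch.float16
```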
Daniël de Kok 2024-05-25 08:48:01 +00:00
parent 9231098f3a
commit 6f30a13afa
5 changed files with 605 additions and 3 deletions


@@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -9.640625,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 604,
"logprob": -2.4296875,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -2.4453125,
"special": false,
"text": " the"
},
{
"id": 2412,
"logprob": -2.8632812,
"special": false,
"text": " following"
},
{
"id": 235292,
"logprob": -2.1328125,
"special": false,
"text": ":"
},
{
"id": 109,
"logprob": -0.76660156,
"special": false,
"text": "\n\n"
},
{
"id": 235287,
"logprob": -1.3837891,
"special": false,
"text": "*"
},
{
"id": 235248,
"logprob": -1.9746094,
"special": false,
"text": " "
},
{
"id": 199,
"logprob": -1.4189453,
"special": false,
"text": "<strong>"
},
{
"id": 1232,
"logprob": -4.34375,
"special": false,
"text": "Name"
},
{
"id": 208,
"logprob": -0.8852539,
"special": false,
"text": "</strong>"
}
],
"top_tokens": null
},
"generated_text": " for the following:\n\n* <strong>Name</strong>"
}


@@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -9.65625,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.3671875,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 604,
"logprob": -0.36938477,
"special": false,
"text": " for"
},
{
"id": 235248,
"logprob": -1.8046875,
"special": false,
"text": " "
},
{
"id": 235274,
"logprob": -0.46240234,
"special": false,
"text": "1"
},
{
"id": 235284,
"logprob": -1.7460938,
"special": false,
"text": "2"
},
{
"id": 235265,
"logprob": -1.9443359,
"special": false,
"text": "."
},
{
"id": 235284,
"logprob": -1.4550781,
"special": false,
"text": "2"
},
{
"id": 235308,
"logprob": -1.0205078,
"special": false,
"text": "5"
},
{
"id": 235290,
"logprob": -1.0283203,
"special": false,
"text": "-"
},
{
"id": 235274,
"logprob": -1.2783203,
"special": false,
"text": "1"
},
{
"id": 235284,
"logprob": 0.0,
"special": false,
"text": "2"
}
],
"top_tokens": null
},
"generated_text": "Test request for 12.25-12"
}


@@ -0,0 +1,358 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -9.6484375,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.359375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 604,
"logprob": -2.4277344,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -2.4394531,
"special": false,
"text": " the"
},
{
"id": 2412,
"logprob": -2.8613281,
"special": false,
"text": " following"
},
{
"id": 235292,
"logprob": -2.1523438,
"special": false,
"text": ":"
},
{
"id": 109,
"logprob": -0.76220703,
"special": false,
"text": "\n\n"
},
{
"id": 235287,
"logprob": -1.3642578,
"special": false,
"text": "*"
},
{
"id": 235248,
"logprob": -2.0175781,
"special": false,
"text": " "
},
{
"id": 199,
"logprob": -1.4238281,
"special": false,
"text": "<strong>"
},
{
"id": 1232,
"logprob": -4.328125,
"special": false,
"text": "Name"
},
{
"id": 208,
"logprob": -0.8881836,
"special": false,
"text": "</strong>"
}
],
"top_tokens": null
},
"generated_text": " for the following:\n\n* <strong>Name</strong>"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -9.6484375,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 604,
"logprob": -2.4238281,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -2.4453125,
"special": false,
"text": " the"
},
{
"id": 2412,
"logprob": -2.859375,
"special": false,
"text": " following"
},
{
"id": 235292,
"logprob": -2.1445312,
"special": false,
"text": ":"
},
{
"id": 109,
"logprob": -0.7631836,
"special": false,
"text": "\n\n"
},
{
"id": 235287,
"logprob": -1.3642578,
"special": false,
"text": "*"
},
{
"id": 235248,
"logprob": -1.9960938,
"special": false,
"text": " "
},
{
"id": 199,
"logprob": -1.4179688,
"special": false,
"text": "<strong>"
},
{
"id": 1232,
"logprob": -4.3359375,
"special": false,
"text": "Name"
},
{
"id": 208,
"logprob": -0.8847656,
"special": false,
"text": "</strong>"
}
],
"top_tokens": null
},
"generated_text": " for the following:\n\n* <strong>Name</strong>"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -9.640625,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.3671875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 604,
"logprob": -2.4257812,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -2.4453125,
"special": false,
"text": " the"
},
{
"id": 2412,
"logprob": -2.8789062,
"special": false,
"text": " following"
},
{
"id": 235292,
"logprob": -2.1367188,
"special": false,
"text": ":"
},
{
"id": 109,
"logprob": -0.76171875,
"special": false,
"text": "\n\n"
},
{
"id": 235287,
"logprob": -1.3515625,
"special": false,
"text": "*"
},
{
"id": 235248,
"logprob": -1.9873047,
"special": false,
"text": " "
},
{
"id": 199,
"logprob": -1.4169922,
"special": false,
"text": "<strong>"
},
{
"id": 1232,
"logprob": -4.3320312,
"special": false,
"text": "Name"
},
{
"id": 208,
"logprob": -0.8930664,
"special": false,
"text": "</strong>"
}
],
"top_tokens": null
},
"generated_text": " for the following:\n\n* <strong>Name</strong>"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -9.6484375,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.359375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 604,
"logprob": -2.4179688,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -2.4492188,
"special": false,
"text": " the"
},
{
"id": 2412,
"logprob": -2.8574219,
"special": false,
"text": " following"
},
{
"id": 235292,
"logprob": -2.1445312,
"special": false,
"text": ":"
},
{
"id": 109,
"logprob": -0.7519531,
"special": false,
"text": "\n\n"
},
{
"id": 235287,
"logprob": -1.3623047,
"special": false,
"text": "*"
},
{
"id": 235248,
"logprob": -1.9707031,
"special": false,
"text": " "
},
{
"id": 199,
"logprob": -1.4267578,
"special": false,
"text": "<strong>"
},
{
"id": 1232,
"logprob": -4.3359375,
"special": false,
"text": "Name"
},
{
"id": 208,
"logprob": -0.88427734,
"special": false,
"text": "</strong>"
}
],
"top_tokens": null
},
"generated_text": " for the following:\n\n* <strong>Name</strong>"
}
]


@@ -0,0 +1,62 @@
import pytest


@pytest.fixture(scope="module")
def flash_gemma_gptq_handle(launcher):
    # Spin up a server for the GPTQ-quantized Gemma model.
    with launcher("TechxGenus/gemma-2b-GPTQ", num_shard=1, quantize="gptq") as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_gemma_gptq(flash_gemma_gptq_handle):
    await flash_gemma_gptq_handle.health(300)
    return flash_gemma_gptq_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq(flash_gemma_gptq, response_snapshot):
    # Greedy generation with default parameters.
    response = await flash_gemma_gptq.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot):
    # Seeded generation exercising all sampling parameters.
    response = await flash_gemma_gptq.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq_load(
    flash_gemma_gptq, generate_load, response_snapshot
):
    # Four identical requests should all produce the same greedy output.
    responses = await generate_load(
        flash_gemma_gptq, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot


@@ -263,9 +263,13 @@ def get_model(
     trust_remote_code: bool,
 ) -> Model:
     if dtype is None:
-        # Keep it as default for now and let
-        # every model resolve their own default dtype.
-        dtype = None
+        if quantize in ["awq", "gptq"]:
+            # These quantizers only work with float16 params.
+            dtype = torch.float16
+        else:
+            # Keep it as default for now and let
+            # every model resolve their own default dtype.
+            dtype = None
     elif dtype == "float16":
         dtype = torch.float16
     elif dtype == "bfloat16":