mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Fix GPTQ for models which do not have float16 at the default dtype
Before this change GPTQ models would not work if the model's default data type is not `float16`. For example, Gemma GPTQ models would fail because the default dtype of Gemma is `bfloat16`. There are two issues: If the default `dtype` is not `float16`, the quantizer's `float16` parameters get converted to that dtype. The kernels cannot deal with non-`float16` types. The same applies to inputs of quantized ops. This is resolved by setting the dtype of gptq/awq-quantized models to `float16`.
This commit is contained in:
parent
9231098f3a
commit
6f30a13afa
@ -0,0 +1,89 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2,
|
||||
"logprob": null,
|
||||
"text": "<bos>"
|
||||
},
|
||||
{
|
||||
"id": 2015,
|
||||
"logprob": -9.640625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -10.34375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 604,
|
||||
"logprob": -2.4296875,
|
||||
"special": false,
|
||||
"text": " for"
|
||||
},
|
||||
{
|
||||
"id": 573,
|
||||
"logprob": -2.4453125,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 2412,
|
||||
"logprob": -2.8632812,
|
||||
"special": false,
|
||||
"text": " following"
|
||||
},
|
||||
{
|
||||
"id": 235292,
|
||||
"logprob": -2.1328125,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 109,
|
||||
"logprob": -0.76660156,
|
||||
"special": false,
|
||||
"text": "\n\n"
|
||||
},
|
||||
{
|
||||
"id": 235287,
|
||||
"logprob": -1.3837891,
|
||||
"special": false,
|
||||
"text": "*"
|
||||
},
|
||||
{
|
||||
"id": 235248,
|
||||
"logprob": -1.9746094,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 199,
|
||||
"logprob": -1.4189453,
|
||||
"special": false,
|
||||
"text": "<strong>"
|
||||
},
|
||||
{
|
||||
"id": 1232,
|
||||
"logprob": -4.34375,
|
||||
"special": false,
|
||||
"text": "Name"
|
||||
},
|
||||
{
|
||||
"id": 208,
|
||||
"logprob": -0.8852539,
|
||||
"special": false,
|
||||
"text": "</strong>"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||
}
|
@ -0,0 +1,89 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2,
|
||||
"logprob": null,
|
||||
"text": "<bos>"
|
||||
},
|
||||
{
|
||||
"id": 2015,
|
||||
"logprob": -9.65625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -10.3671875,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 604,
|
||||
"logprob": -0.36938477,
|
||||
"special": false,
|
||||
"text": " for"
|
||||
},
|
||||
{
|
||||
"id": 235248,
|
||||
"logprob": -1.8046875,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 235274,
|
||||
"logprob": -0.46240234,
|
||||
"special": false,
|
||||
"text": "1"
|
||||
},
|
||||
{
|
||||
"id": 235284,
|
||||
"logprob": -1.7460938,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
},
|
||||
{
|
||||
"id": 235265,
|
||||
"logprob": -1.9443359,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 235284,
|
||||
"logprob": -1.4550781,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
},
|
||||
{
|
||||
"id": 235308,
|
||||
"logprob": -1.0205078,
|
||||
"special": false,
|
||||
"text": "5"
|
||||
},
|
||||
{
|
||||
"id": 235290,
|
||||
"logprob": -1.0283203,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 235274,
|
||||
"logprob": -1.2783203,
|
||||
"special": false,
|
||||
"text": "1"
|
||||
},
|
||||
{
|
||||
"id": 235284,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Test request for 12.25-12"
|
||||
}
|
@ -0,0 +1,358 @@
|
||||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2,
|
||||
"logprob": null,
|
||||
"text": "<bos>"
|
||||
},
|
||||
{
|
||||
"id": 2015,
|
||||
"logprob": -9.6484375,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -10.359375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 604,
|
||||
"logprob": -2.4277344,
|
||||
"special": false,
|
||||
"text": " for"
|
||||
},
|
||||
{
|
||||
"id": 573,
|
||||
"logprob": -2.4394531,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 2412,
|
||||
"logprob": -2.8613281,
|
||||
"special": false,
|
||||
"text": " following"
|
||||
},
|
||||
{
|
||||
"id": 235292,
|
||||
"logprob": -2.1523438,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 109,
|
||||
"logprob": -0.76220703,
|
||||
"special": false,
|
||||
"text": "\n\n"
|
||||
},
|
||||
{
|
||||
"id": 235287,
|
||||
"logprob": -1.3642578,
|
||||
"special": false,
|
||||
"text": "*"
|
||||
},
|
||||
{
|
||||
"id": 235248,
|
||||
"logprob": -2.0175781,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 199,
|
||||
"logprob": -1.4238281,
|
||||
"special": false,
|
||||
"text": "<strong>"
|
||||
},
|
||||
{
|
||||
"id": 1232,
|
||||
"logprob": -4.328125,
|
||||
"special": false,
|
||||
"text": "Name"
|
||||
},
|
||||
{
|
||||
"id": 208,
|
||||
"logprob": -0.8881836,
|
||||
"special": false,
|
||||
"text": "</strong>"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2,
|
||||
"logprob": null,
|
||||
"text": "<bos>"
|
||||
},
|
||||
{
|
||||
"id": 2015,
|
||||
"logprob": -9.6484375,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -10.34375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 604,
|
||||
"logprob": -2.4238281,
|
||||
"special": false,
|
||||
"text": " for"
|
||||
},
|
||||
{
|
||||
"id": 573,
|
||||
"logprob": -2.4453125,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 2412,
|
||||
"logprob": -2.859375,
|
||||
"special": false,
|
||||
"text": " following"
|
||||
},
|
||||
{
|
||||
"id": 235292,
|
||||
"logprob": -2.1445312,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 109,
|
||||
"logprob": -0.7631836,
|
||||
"special": false,
|
||||
"text": "\n\n"
|
||||
},
|
||||
{
|
||||
"id": 235287,
|
||||
"logprob": -1.3642578,
|
||||
"special": false,
|
||||
"text": "*"
|
||||
},
|
||||
{
|
||||
"id": 235248,
|
||||
"logprob": -1.9960938,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 199,
|
||||
"logprob": -1.4179688,
|
||||
"special": false,
|
||||
"text": "<strong>"
|
||||
},
|
||||
{
|
||||
"id": 1232,
|
||||
"logprob": -4.3359375,
|
||||
"special": false,
|
||||
"text": "Name"
|
||||
},
|
||||
{
|
||||
"id": 208,
|
||||
"logprob": -0.8847656,
|
||||
"special": false,
|
||||
"text": "</strong>"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2,
|
||||
"logprob": null,
|
||||
"text": "<bos>"
|
||||
},
|
||||
{
|
||||
"id": 2015,
|
||||
"logprob": -9.640625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -10.3671875,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 604,
|
||||
"logprob": -2.4257812,
|
||||
"special": false,
|
||||
"text": " for"
|
||||
},
|
||||
{
|
||||
"id": 573,
|
||||
"logprob": -2.4453125,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 2412,
|
||||
"logprob": -2.8789062,
|
||||
"special": false,
|
||||
"text": " following"
|
||||
},
|
||||
{
|
||||
"id": 235292,
|
||||
"logprob": -2.1367188,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 109,
|
||||
"logprob": -0.76171875,
|
||||
"special": false,
|
||||
"text": "\n\n"
|
||||
},
|
||||
{
|
||||
"id": 235287,
|
||||
"logprob": -1.3515625,
|
||||
"special": false,
|
||||
"text": "*"
|
||||
},
|
||||
{
|
||||
"id": 235248,
|
||||
"logprob": -1.9873047,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 199,
|
||||
"logprob": -1.4169922,
|
||||
"special": false,
|
||||
"text": "<strong>"
|
||||
},
|
||||
{
|
||||
"id": 1232,
|
||||
"logprob": -4.3320312,
|
||||
"special": false,
|
||||
"text": "Name"
|
||||
},
|
||||
{
|
||||
"id": 208,
|
||||
"logprob": -0.8930664,
|
||||
"special": false,
|
||||
"text": "</strong>"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2,
|
||||
"logprob": null,
|
||||
"text": "<bos>"
|
||||
},
|
||||
{
|
||||
"id": 2015,
|
||||
"logprob": -9.6484375,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -10.359375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 604,
|
||||
"logprob": -2.4179688,
|
||||
"special": false,
|
||||
"text": " for"
|
||||
},
|
||||
{
|
||||
"id": 573,
|
||||
"logprob": -2.4492188,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 2412,
|
||||
"logprob": -2.8574219,
|
||||
"special": false,
|
||||
"text": " following"
|
||||
},
|
||||
{
|
||||
"id": 235292,
|
||||
"logprob": -2.1445312,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 109,
|
||||
"logprob": -0.7519531,
|
||||
"special": false,
|
||||
"text": "\n\n"
|
||||
},
|
||||
{
|
||||
"id": 235287,
|
||||
"logprob": -1.3623047,
|
||||
"special": false,
|
||||
"text": "*"
|
||||
},
|
||||
{
|
||||
"id": 235248,
|
||||
"logprob": -1.9707031,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 199,
|
||||
"logprob": -1.4267578,
|
||||
"special": false,
|
||||
"text": "<strong>"
|
||||
},
|
||||
{
|
||||
"id": 1232,
|
||||
"logprob": -4.3359375,
|
||||
"special": false,
|
||||
"text": "Name"
|
||||
},
|
||||
{
|
||||
"id": 208,
|
||||
"logprob": -0.88427734,
|
||||
"special": false,
|
||||
"text": "</strong>"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||
}
|
||||
]
|
62
integration-tests/models/test_flash_gemma_gptq.py
Normal file
62
integration-tests/models/test_flash_gemma_gptq.py
Normal file
@ -0,0 +1,62 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_gemma_gptq_handle(launcher):
|
||||
with launcher("TechxGenus/gemma-2b-GPTQ", num_shard=1, quantize="gptq") as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_gemma_gptq(flash_gemma_gptq_handle):
|
||||
await flash_gemma_gptq_handle.health(300)
|
||||
return flash_gemma_gptq_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_gemma_gptq(flash_gemma_gptq, response_snapshot):
|
||||
response = await flash_gemma_gptq.generate(
|
||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot):
|
||||
response = await flash_gemma_gptq.generate(
|
||||
"Test request",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
stop_sequences=["test"],
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_gemma_gptq_load(
|
||||
flash_gemma_gptq, generate_load, response_snapshot
|
||||
):
|
||||
responses = await generate_load(
|
||||
flash_gemma_gptq, "Test request", max_new_tokens=10, n=4
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
@ -263,9 +263,13 @@ def get_model(
|
||||
trust_remote_code: bool,
|
||||
) -> Model:
|
||||
if dtype is None:
|
||||
# Keep it as default for now and let
|
||||
# every model resolve their own default dtype.
|
||||
dtype = None
|
||||
if quantize in ["awq", "gptq"]:
|
||||
# These quantizers only work with float16 params.
|
||||
dtype = torch.float16
|
||||
else:
|
||||
# Keep it as default for now and let
|
||||
# every model resolve their own default dtype.
|
||||
dtype = None
|
||||
elif dtype == "float16":
|
||||
dtype = torch.float16
|
||||
elif dtype == "bfloat16":
|
||||
|
Loading…
Reference in New Issue
Block a user