mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Fix GPTQ for models which do not have float16 at the default dtype
Before this change GPTQ models would not work if the model's default data type is not `float16`. For example, Gemma GPTQ models would fail because the default dtype of Gemma is `bfloat16`. There are two issues: If the default `dtype` is not `float16`, the quantizer's `float16` parameters get converted to that dtype. The kernels cannot deal with non-`float16` types. The same applies to inputs of quantized ops. This is resolved by setting the dtype of gptq/awq-quantized models to `float16`.
This commit is contained in:
parent
9231098f3a
commit
6f30a13afa
@ -0,0 +1,89 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2015,
|
||||||
|
"logprob": -9.640625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3853,
|
||||||
|
"logprob": -10.34375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -2.4296875,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -2.4453125,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2412,
|
||||||
|
"logprob": -2.8632812,
|
||||||
|
"special": false,
|
||||||
|
"text": " following"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235292,
|
||||||
|
"logprob": -2.1328125,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 109,
|
||||||
|
"logprob": -0.76660156,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235287,
|
||||||
|
"logprob": -1.3837891,
|
||||||
|
"special": false,
|
||||||
|
"text": "*"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -1.9746094,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 199,
|
||||||
|
"logprob": -1.4189453,
|
||||||
|
"special": false,
|
||||||
|
"text": "<strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1232,
|
||||||
|
"logprob": -4.34375,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 208,
|
||||||
|
"logprob": -0.8852539,
|
||||||
|
"special": false,
|
||||||
|
"text": "</strong>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||||
|
}
|
@ -0,0 +1,89 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2015,
|
||||||
|
"logprob": -9.65625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3853,
|
||||||
|
"logprob": -10.3671875,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": 0,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -0.36938477,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -1.8046875,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235274,
|
||||||
|
"logprob": -0.46240234,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235284,
|
||||||
|
"logprob": -1.7460938,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235265,
|
||||||
|
"logprob": -1.9443359,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235284,
|
||||||
|
"logprob": -1.4550781,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235308,
|
||||||
|
"logprob": -1.0205078,
|
||||||
|
"special": false,
|
||||||
|
"text": "5"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235290,
|
||||||
|
"logprob": -1.0283203,
|
||||||
|
"special": false,
|
||||||
|
"text": "-"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235274,
|
||||||
|
"logprob": -1.2783203,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235284,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Test request for 12.25-12"
|
||||||
|
}
|
@ -0,0 +1,358 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2015,
|
||||||
|
"logprob": -9.6484375,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3853,
|
||||||
|
"logprob": -10.359375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -2.4277344,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -2.4394531,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2412,
|
||||||
|
"logprob": -2.8613281,
|
||||||
|
"special": false,
|
||||||
|
"text": " following"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235292,
|
||||||
|
"logprob": -2.1523438,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 109,
|
||||||
|
"logprob": -0.76220703,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235287,
|
||||||
|
"logprob": -1.3642578,
|
||||||
|
"special": false,
|
||||||
|
"text": "*"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -2.0175781,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 199,
|
||||||
|
"logprob": -1.4238281,
|
||||||
|
"special": false,
|
||||||
|
"text": "<strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1232,
|
||||||
|
"logprob": -4.328125,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 208,
|
||||||
|
"logprob": -0.8881836,
|
||||||
|
"special": false,
|
||||||
|
"text": "</strong>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2015,
|
||||||
|
"logprob": -9.6484375,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3853,
|
||||||
|
"logprob": -10.34375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -2.4238281,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -2.4453125,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2412,
|
||||||
|
"logprob": -2.859375,
|
||||||
|
"special": false,
|
||||||
|
"text": " following"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235292,
|
||||||
|
"logprob": -2.1445312,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 109,
|
||||||
|
"logprob": -0.7631836,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235287,
|
||||||
|
"logprob": -1.3642578,
|
||||||
|
"special": false,
|
||||||
|
"text": "*"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -1.9960938,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 199,
|
||||||
|
"logprob": -1.4179688,
|
||||||
|
"special": false,
|
||||||
|
"text": "<strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1232,
|
||||||
|
"logprob": -4.3359375,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 208,
|
||||||
|
"logprob": -0.8847656,
|
||||||
|
"special": false,
|
||||||
|
"text": "</strong>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2015,
|
||||||
|
"logprob": -9.640625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3853,
|
||||||
|
"logprob": -10.3671875,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -2.4257812,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -2.4453125,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2412,
|
||||||
|
"logprob": -2.8789062,
|
||||||
|
"special": false,
|
||||||
|
"text": " following"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235292,
|
||||||
|
"logprob": -2.1367188,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 109,
|
||||||
|
"logprob": -0.76171875,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235287,
|
||||||
|
"logprob": -1.3515625,
|
||||||
|
"special": false,
|
||||||
|
"text": "*"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -1.9873047,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 199,
|
||||||
|
"logprob": -1.4169922,
|
||||||
|
"special": false,
|
||||||
|
"text": "<strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1232,
|
||||||
|
"logprob": -4.3320312,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 208,
|
||||||
|
"logprob": -0.8930664,
|
||||||
|
"special": false,
|
||||||
|
"text": "</strong>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2015,
|
||||||
|
"logprob": -9.6484375,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3853,
|
||||||
|
"logprob": -10.359375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -2.4179688,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -2.4492188,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2412,
|
||||||
|
"logprob": -2.8574219,
|
||||||
|
"special": false,
|
||||||
|
"text": " following"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235292,
|
||||||
|
"logprob": -2.1445312,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 109,
|
||||||
|
"logprob": -0.7519531,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235287,
|
||||||
|
"logprob": -1.3623047,
|
||||||
|
"special": false,
|
||||||
|
"text": "*"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -1.9707031,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 199,
|
||||||
|
"logprob": -1.4267578,
|
||||||
|
"special": false,
|
||||||
|
"text": "<strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1232,
|
||||||
|
"logprob": -4.3359375,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 208,
|
||||||
|
"logprob": -0.88427734,
|
||||||
|
"special": false,
|
||||||
|
"text": "</strong>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||||
|
}
|
||||||
|
]
|
62
integration-tests/models/test_flash_gemma_gptq.py
Normal file
62
integration-tests/models/test_flash_gemma_gptq.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def flash_gemma_gptq_handle(launcher):
|
||||||
|
with launcher("TechxGenus/gemma-2b-GPTQ", num_shard=1, quantize="gptq") as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def flash_gemma_gptq(flash_gemma_gptq_handle):
|
||||||
|
await flash_gemma_gptq_handle.health(300)
|
||||||
|
return flash_gemma_gptq_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_gemma_gptq(flash_gemma_gptq, response_snapshot):
|
||||||
|
response = await flash_gemma_gptq.generate(
|
||||||
|
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot):
|
||||||
|
response = await flash_gemma_gptq.generate(
|
||||||
|
"Test request",
|
||||||
|
max_new_tokens=10,
|
||||||
|
repetition_penalty=1.2,
|
||||||
|
return_full_text=True,
|
||||||
|
stop_sequences=["test"],
|
||||||
|
temperature=0.5,
|
||||||
|
top_p=0.9,
|
||||||
|
top_k=10,
|
||||||
|
truncate=5,
|
||||||
|
typical_p=0.9,
|
||||||
|
watermark=True,
|
||||||
|
decoder_input_details=True,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_gemma_gptq_load(
|
||||||
|
flash_gemma_gptq, generate_load, response_snapshot
|
||||||
|
):
|
||||||
|
responses = await generate_load(
|
||||||
|
flash_gemma_gptq, "Test request", max_new_tokens=10, n=4
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(responses) == 4
|
||||||
|
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||||
|
|
||||||
|
assert responses == response_snapshot
|
@ -263,9 +263,13 @@ def get_model(
|
|||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
) -> Model:
|
) -> Model:
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
# Keep it as default for now and let
|
if quantize in ["awq", "gptq"]:
|
||||||
# every model resolve their own default dtype.
|
# These quantizers only work with float16 params.
|
||||||
dtype = None
|
dtype = torch.float16
|
||||||
|
else:
|
||||||
|
# Keep it as default for now and let
|
||||||
|
# every model resolve their own default dtype.
|
||||||
|
dtype = None
|
||||||
elif dtype == "float16":
|
elif dtype == "float16":
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
elif dtype == "bfloat16":
|
elif dtype == "bfloat16":
|
||||||
|
Loading…
Reference in New Issue
Block a user