mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Fix GPTQ for models which do not have float16 at the default dtype
Before this change GPTQ models would not work if the model's default data type is not `float16`. For example, Gemma GPTQ models would fail because the default dtype of Gemma is `bfloat16`. There are two issues: 1. If the default `dtype` is not `float16`, the quantizer's `float16` parameters get converted to that dtype. The kernels cannot deal with non-`float16` types. This change resolves this issue by excluding quantizer parameters from data type conversions. 2. Quantized models will typically have `float16` parameters. However, the default dtype was set to model's default. So, if a quantized Gemma uses `float16`, all parameters are converted to `bfloat16` since it is the model's default. This fails in quantized gemm, because it expects `float16` arguments. This is resolved by setting the dtype of gptq/awq-quantized models to `float16`. (We cannot use `torch_dtype` from the config, because it often does not correspond to the dtype of the parameters.)
This commit is contained in:
parent
9231098f3a
commit
b9b5051abc
@ -0,0 +1,89 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2015,
|
||||||
|
"logprob": -9.640625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3853,
|
||||||
|
"logprob": -10.34375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -2.4296875,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -2.4453125,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2412,
|
||||||
|
"logprob": -2.8632812,
|
||||||
|
"special": false,
|
||||||
|
"text": " following"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235292,
|
||||||
|
"logprob": -2.1328125,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 109,
|
||||||
|
"logprob": -0.76660156,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235287,
|
||||||
|
"logprob": -1.3837891,
|
||||||
|
"special": false,
|
||||||
|
"text": "*"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -1.9746094,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 199,
|
||||||
|
"logprob": -1.4189453,
|
||||||
|
"special": false,
|
||||||
|
"text": "<strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1232,
|
||||||
|
"logprob": -4.34375,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 208,
|
||||||
|
"logprob": -0.8852539,
|
||||||
|
"special": false,
|
||||||
|
"text": "</strong>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||||
|
}
|
@ -0,0 +1,89 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2015,
|
||||||
|
"logprob": -9.65625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3853,
|
||||||
|
"logprob": -10.3671875,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": 0,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -0.36938477,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -1.8046875,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235274,
|
||||||
|
"logprob": -0.46240234,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235284,
|
||||||
|
"logprob": -1.7460938,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235265,
|
||||||
|
"logprob": -1.9443359,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235284,
|
||||||
|
"logprob": -1.4550781,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235308,
|
||||||
|
"logprob": -1.0205078,
|
||||||
|
"special": false,
|
||||||
|
"text": "5"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235290,
|
||||||
|
"logprob": -1.0283203,
|
||||||
|
"special": false,
|
||||||
|
"text": "-"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235274,
|
||||||
|
"logprob": -1.2783203,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235284,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Test request for 12.25-12"
|
||||||
|
}
|
@ -0,0 +1,358 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2015,
|
||||||
|
"logprob": -9.6484375,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3853,
|
||||||
|
"logprob": -10.359375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -2.4277344,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -2.4394531,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2412,
|
||||||
|
"logprob": -2.8613281,
|
||||||
|
"special": false,
|
||||||
|
"text": " following"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235292,
|
||||||
|
"logprob": -2.1523438,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 109,
|
||||||
|
"logprob": -0.76220703,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235287,
|
||||||
|
"logprob": -1.3642578,
|
||||||
|
"special": false,
|
||||||
|
"text": "*"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -2.0175781,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 199,
|
||||||
|
"logprob": -1.4238281,
|
||||||
|
"special": false,
|
||||||
|
"text": "<strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1232,
|
||||||
|
"logprob": -4.328125,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 208,
|
||||||
|
"logprob": -0.8881836,
|
||||||
|
"special": false,
|
||||||
|
"text": "</strong>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2015,
|
||||||
|
"logprob": -9.6484375,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3853,
|
||||||
|
"logprob": -10.34375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -2.4238281,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -2.4453125,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2412,
|
||||||
|
"logprob": -2.859375,
|
||||||
|
"special": false,
|
||||||
|
"text": " following"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235292,
|
||||||
|
"logprob": -2.1445312,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 109,
|
||||||
|
"logprob": -0.7631836,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235287,
|
||||||
|
"logprob": -1.3642578,
|
||||||
|
"special": false,
|
||||||
|
"text": "*"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -1.9960938,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 199,
|
||||||
|
"logprob": -1.4179688,
|
||||||
|
"special": false,
|
||||||
|
"text": "<strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1232,
|
||||||
|
"logprob": -4.3359375,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 208,
|
||||||
|
"logprob": -0.8847656,
|
||||||
|
"special": false,
|
||||||
|
"text": "</strong>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2015,
|
||||||
|
"logprob": -9.640625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3853,
|
||||||
|
"logprob": -10.3671875,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -2.4257812,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -2.4453125,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2412,
|
||||||
|
"logprob": -2.8789062,
|
||||||
|
"special": false,
|
||||||
|
"text": " following"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235292,
|
||||||
|
"logprob": -2.1367188,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 109,
|
||||||
|
"logprob": -0.76171875,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235287,
|
||||||
|
"logprob": -1.3515625,
|
||||||
|
"special": false,
|
||||||
|
"text": "*"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -1.9873047,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 199,
|
||||||
|
"logprob": -1.4169922,
|
||||||
|
"special": false,
|
||||||
|
"text": "<strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1232,
|
||||||
|
"logprob": -4.3320312,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 208,
|
||||||
|
"logprob": -0.8930664,
|
||||||
|
"special": false,
|
||||||
|
"text": "</strong>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2015,
|
||||||
|
"logprob": -9.6484375,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3853,
|
||||||
|
"logprob": -10.359375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -2.4179688,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -2.4492188,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2412,
|
||||||
|
"logprob": -2.8574219,
|
||||||
|
"special": false,
|
||||||
|
"text": " following"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235292,
|
||||||
|
"logprob": -2.1445312,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 109,
|
||||||
|
"logprob": -0.7519531,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235287,
|
||||||
|
"logprob": -1.3623047,
|
||||||
|
"special": false,
|
||||||
|
"text": "*"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -1.9707031,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 199,
|
||||||
|
"logprob": -1.4267578,
|
||||||
|
"special": false,
|
||||||
|
"text": "<strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1232,
|
||||||
|
"logprob": -4.3359375,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 208,
|
||||||
|
"logprob": -0.88427734,
|
||||||
|
"special": false,
|
||||||
|
"text": "</strong>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||||
|
}
|
||||||
|
]
|
62
integration-tests/models/test_flash_gemma_gptq.py
Normal file
62
integration-tests/models/test_flash_gemma_gptq.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def flash_gemma_gptq_handle(launcher):
|
||||||
|
with launcher("TechxGenus/gemma-2b-GPTQ", num_shard=1, quantize="gptq") as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def flash_gemma_gptq(flash_gemma_gptq_handle):
|
||||||
|
await flash_gemma_gptq_handle.health(300)
|
||||||
|
return flash_gemma_gptq_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_gemma_gptq(flash_gemma_gptq, response_snapshot):
|
||||||
|
response = await flash_gemma_gptq.generate(
|
||||||
|
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot):
|
||||||
|
response = await flash_gemma_gptq.generate(
|
||||||
|
"Test request",
|
||||||
|
max_new_tokens=10,
|
||||||
|
repetition_penalty=1.2,
|
||||||
|
return_full_text=True,
|
||||||
|
stop_sequences=["test"],
|
||||||
|
temperature=0.5,
|
||||||
|
top_p=0.9,
|
||||||
|
top_k=10,
|
||||||
|
truncate=5,
|
||||||
|
typical_p=0.9,
|
||||||
|
watermark=True,
|
||||||
|
decoder_input_details=True,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_gemma_gptq_load(
|
||||||
|
flash_gemma_gptq, generate_load, response_snapshot
|
||||||
|
):
|
||||||
|
responses = await generate_load(
|
||||||
|
flash_gemma_gptq, "Test request", max_new_tokens=10, n=4
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(responses) == 4
|
||||||
|
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||||
|
|
||||||
|
assert responses == response_snapshot
|
@ -263,9 +263,13 @@ def get_model(
|
|||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
) -> Model:
|
) -> Model:
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
# Keep it as default for now and let
|
if quantize in ["awq", "gptq"]:
|
||||||
# every model resolve their own default dtype.
|
# These quantizers only work with float16 params.
|
||||||
dtype = None
|
dtype = torch.float16
|
||||||
|
else:
|
||||||
|
# Keep it as default for now and let
|
||||||
|
# every model resolve their own default dtype.
|
||||||
|
dtype = None
|
||||||
elif dtype == "float16":
|
elif dtype == "float16":
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
elif dtype == "bfloat16":
|
elif dtype == "bfloat16":
|
||||||
|
@ -78,7 +78,7 @@ def _load_multi_mqa_gptq(
|
|||||||
quant_method,
|
quant_method,
|
||||||
) = weights._get_gptq_params()
|
) = weights._get_gptq_params()
|
||||||
if quant_method == "gptq":
|
if quant_method == "gptq":
|
||||||
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
|
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx", to_dtype=False)
|
||||||
g_idx = g_idx.to(device=weights.device)
|
g_idx = g_idx.to(device=weights.device)
|
||||||
elif quant_method == "awq":
|
elif quant_method == "awq":
|
||||||
g_idx = None
|
g_idx = None
|
||||||
|
@ -71,19 +71,19 @@ class Weights:
|
|||||||
def get_shape(self, tensor_name: str):
|
def get_shape(self, tensor_name: str):
|
||||||
return self._get_slice(tensor_name).get_shape()
|
return self._get_slice(tensor_name).get_shape()
|
||||||
|
|
||||||
def get_tensor(self, tensor_name: str, to_device=True):
|
def get_tensor(
|
||||||
|
self, tensor_name: str, to_device: bool = True, to_dtype: bool = True
|
||||||
|
):
|
||||||
filename, tensor_name = self.get_filename(tensor_name)
|
filename, tensor_name = self.get_filename(tensor_name)
|
||||||
f = self._get_handle(filename)
|
f = self._get_handle(filename)
|
||||||
tensor = f.get_tensor(tensor_name)
|
tensor = f.get_tensor(tensor_name)
|
||||||
# Special case for gptq which shouldn't convert
|
if to_dtype:
|
||||||
# u4 which are disguised as int32
|
|
||||||
if tensor.dtype not in [torch.int32, torch.int64]:
|
|
||||||
tensor = tensor.to(dtype=self.dtype)
|
tensor = tensor.to(dtype=self.dtype)
|
||||||
if to_device:
|
if to_device:
|
||||||
tensor = tensor.to(device=self.device)
|
tensor = tensor.to(device=self.device)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def get_partial_sharded(self, tensor_name: str, dim: int):
|
def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype: bool = True):
|
||||||
filename, tensor_name = self.get_filename(tensor_name)
|
filename, tensor_name = self.get_filename(tensor_name)
|
||||||
f = self._get_handle(filename)
|
f = self._get_handle(filename)
|
||||||
slice_ = f.get_slice(tensor_name)
|
slice_ = f.get_slice(tensor_name)
|
||||||
@ -101,14 +101,12 @@ class Weights:
|
|||||||
tensor = slice_[:, start:stop]
|
tensor = slice_[:, start:stop]
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Let's make that generic when needed")
|
raise NotImplementedError("Let's make that generic when needed")
|
||||||
# Special case for gptq which shouldn't convert
|
if to_dtype:
|
||||||
# u4 which are disguised as int32
|
|
||||||
if tensor.dtype != torch.int32:
|
|
||||||
tensor = tensor.to(dtype=self.dtype)
|
tensor = tensor.to(dtype=self.dtype)
|
||||||
tensor = tensor.to(device=self.device)
|
tensor = tensor.to(device=self.device)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def get_sharded(self, tensor_name: str, dim: int):
|
def get_sharded(self, tensor_name: str, dim: int, to_dtype: bool = True):
|
||||||
filename, tensor_name = self.get_filename(tensor_name)
|
filename, tensor_name = self.get_filename(tensor_name)
|
||||||
f = self._get_handle(filename)
|
f = self._get_handle(filename)
|
||||||
slice_ = f.get_slice(tensor_name)
|
slice_ = f.get_slice(tensor_name)
|
||||||
@ -117,7 +115,7 @@ class Weights:
|
|||||||
assert (
|
assert (
|
||||||
size % world_size == 0
|
size % world_size == 0
|
||||||
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
|
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
|
||||||
return self.get_partial_sharded(tensor_name, dim)
|
return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype)
|
||||||
|
|
||||||
def _get_qweight(self, name: str):
|
def _get_qweight(self, name: str):
|
||||||
slice_ = self._get_slice(name)
|
slice_ = self._get_slice(name)
|
||||||
@ -163,10 +161,9 @@ class Weights:
|
|||||||
|
|
||||||
qzeros = self._get_qweight(f"{prefix}.qzeros")
|
qzeros = self._get_qweight(f"{prefix}.qzeros")
|
||||||
scales = self._get_qweight(f"{prefix}.scales")
|
scales = self._get_qweight(f"{prefix}.scales")
|
||||||
scales = scales.to(dtype=self.dtype)
|
|
||||||
|
|
||||||
if quantize == "gptq" and quant_method == "gptq":
|
if quantize == "gptq" and quant_method == "gptq":
|
||||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
g_idx = self.get_tensor(f"{prefix}.g_idx", to_dtype=False)
|
||||||
elif quantize == "gptq" and quant_method == "awq":
|
elif quantize == "gptq" and quant_method == "awq":
|
||||||
log_once(
|
log_once(
|
||||||
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||||
@ -211,7 +208,11 @@ class Weights:
|
|||||||
if quantize in ["gptq", "awq"]:
|
if quantize in ["gptq", "awq"]:
|
||||||
try:
|
try:
|
||||||
qweight = torch.cat(
|
qweight = torch.cat(
|
||||||
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
[
|
||||||
|
self.get_sharded(f"{p}.qweight", dim=1, to_dtype=False)
|
||||||
|
for p in prefixes
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
)
|
)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@ -219,10 +220,18 @@ class Weights:
|
|||||||
)
|
)
|
||||||
|
|
||||||
qzeros = torch.cat(
|
qzeros = torch.cat(
|
||||||
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
[
|
||||||
|
self.get_sharded(f"{p}.qzeros", dim=1, to_dtype=False)
|
||||||
|
for p in prefixes
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
)
|
)
|
||||||
scales = torch.cat(
|
scales = torch.cat(
|
||||||
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
[
|
||||||
|
self.get_sharded(f"{p}.scales", dim=1, to_dtype=False)
|
||||||
|
for p in prefixes
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
bits, groupsize, desc_act, quant_method = self._get_gptq_params()
|
bits, groupsize, desc_act, quant_method = self._get_gptq_params()
|
||||||
@ -234,7 +243,7 @@ class Weights:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if quantize == "gptq" and quant_method == "gptq":
|
if quantize == "gptq" and quant_method == "gptq":
|
||||||
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
|
w = [self.get_tensor(f"{p}.g_idx", to_dtype=False) for p in prefixes]
|
||||||
for w2 in w[1:]:
|
for w2 in w[1:]:
|
||||||
torch.testing.assert_close(w2, w[0])
|
torch.testing.assert_close(w2, w[0])
|
||||||
g_idx = w[0]
|
g_idx = w[0]
|
||||||
@ -265,22 +274,6 @@ class Weights:
|
|||||||
weight = torch.cat(w, dim=dim)
|
weight = torch.cat(w, dim=dim)
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
def get_tensor_shard(self, var, dim):
|
|
||||||
world_size = self.process_group.size()
|
|
||||||
rank = self.process_group.rank()
|
|
||||||
block_size = var.size()[dim] // world_size
|
|
||||||
start = rank * block_size
|
|
||||||
stop = (rank + 1) * block_size
|
|
||||||
if dim == 0:
|
|
||||||
tensor = var[start:stop]
|
|
||||||
elif dim == 1:
|
|
||||||
tensor = var[:, start:stop]
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Let's make that generic when needed")
|
|
||||||
tensor = tensor.to(dtype=self.dtype)
|
|
||||||
tensor = tensor.to(device=self.device)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
def get_multi_weights_row(self, prefix: str, quantize: str):
|
def get_multi_weights_row(self, prefix: str, quantize: str):
|
||||||
if quantize == "gptq":
|
if quantize == "gptq":
|
||||||
use_exllama = True
|
use_exllama = True
|
||||||
@ -294,14 +287,14 @@ class Weights:
|
|||||||
use_exllama = False
|
use_exllama = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
qweight = self.get_sharded(f"{prefix}.qweight", dim=0, to_dtype=False)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||||
)
|
)
|
||||||
|
|
||||||
if quant_method == "gptq":
|
if quant_method == "gptq":
|
||||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0, to_dtype=False)
|
||||||
elif quant_method == "awq":
|
elif quant_method == "awq":
|
||||||
g_idx = None
|
g_idx = None
|
||||||
|
|
||||||
@ -335,11 +328,11 @@ class Weights:
|
|||||||
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
|
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
|
||||||
|
|
||||||
if use_exllama and groupsize != -1:
|
if use_exllama and groupsize != -1:
|
||||||
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0, to_dtype=False)
|
||||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
scales = self.get_sharded(f"{prefix}.scales", dim=0, to_dtype=False)
|
||||||
else:
|
else:
|
||||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
qzeros = self.get_tensor(f"{prefix}.qzeros", to_dtype=False)
|
||||||
scales = self.get_tensor(f"{prefix}.scales")
|
scales = self.get_tensor(f"{prefix}.scales", to_dtype=False)
|
||||||
|
|
||||||
if use_exllama and g_idx is not None:
|
if use_exllama and g_idx is not None:
|
||||||
g_idx = g_idx - g_idx[0]
|
g_idx = g_idx - g_idx[0]
|
||||||
@ -368,14 +361,14 @@ class Weights:
|
|||||||
bits, groupsize, _, _ = self._get_gptq_params()
|
bits, groupsize, _, _ = self._get_gptq_params()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
qweight = self.get_sharded(f"{prefix}.qweight", dim=0, to_dtype=False)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Cannot load `awq` weight, make sure the model is already quantized"
|
"Cannot load `awq` weight, make sure the model is already quantized"
|
||||||
)
|
)
|
||||||
|
|
||||||
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0, to_dtype=False)
|
||||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
scales = self.get_sharded(f"{prefix}.scales", dim=0, to_dtype=False)
|
||||||
g_idx = None
|
g_idx = None
|
||||||
use_exllama = False
|
use_exllama = False
|
||||||
|
|
||||||
@ -386,8 +379,8 @@ class Weights:
|
|||||||
|
|
||||||
def _get_gptq_params(self) -> Tuple[int, int, int, str]:
|
def _get_gptq_params(self) -> Tuple[int, int, int, str]:
|
||||||
try:
|
try:
|
||||||
bits = self.get_tensor("gptq_bits").item()
|
bits = self.get_tensor("gptq_bits", to_dtype=False).item()
|
||||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
groupsize = self.get_tensor("gptq_groupsize", to_dtype=False).item()
|
||||||
desc_act = False
|
desc_act = False
|
||||||
quant_method = "gptq"
|
quant_method = "gptq"
|
||||||
except (SafetensorError, RuntimeError) as e:
|
except (SafetensorError, RuntimeError) as e:
|
||||||
|
Loading…
Reference in New Issue
Block a user