Fix GPTQ for models which do not have float16 at the default dtype

Before this change, GPTQ models would not work if the model's default
data type is not `float16`. For example, Gemma GPTQ models would fail
because Gemma's default dtype is `bfloat16`. There are two issues:

1. If the default `dtype` is not `float16`, the quantizer's `float16`
   parameters get converted to that dtype. The kernels cannot deal
   with non-`float16` types.

This change resolves the first issue by excluding quantizer parameters
from data type conversions.
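
To illustrate the approach (a sketch only; `maybe_convert` is an
illustrative helper, the real change is the `to_dtype` flag added to
`Weights.get_tensor`/`get_sharded` in the `weights.py` diff below):

    import torch

    def maybe_convert(
        tensor: torch.Tensor, model_dtype: torch.dtype, to_dtype: bool = True
    ) -> torch.Tensor:
        if to_dtype:
            # Ordinary model weights follow the model's dtype.
            return tensor.to(dtype=model_dtype)
        # Quantizer parameters (qweight, qzeros, scales, g_idx) keep the
        # dtype they were stored with.
        return tensor

    scales = torch.ones(4, dtype=torch.float16)  # GPTQ scales are typically stored as float16
    print(maybe_convert(scales, torch.bfloat16).dtype)                  # torch.bfloat16: old behaviour, breaks the kernels
    print(maybe_convert(scales, torch.bfloat16, to_dtype=False).dtype)  # torch.float16: kept as stored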

2. Quantized models typically have `float16` parameters. However, the
   default dtype was set to the model's default. So, if a quantized
   Gemma checkpoint uses `float16`, all parameters still get converted to
   `bfloat16`, since that is the model's default. This fails in the
   quantized GEMM, because it expects `float16` arguments.

The second issue is resolved by setting the dtype of GPTQ/AWQ-quantized
models to `float16`. (We cannot use `torch_dtype` from the config,
because it often does not correspond to the dtype of the parameters.)
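
In sketch form (`resolve_dtype` is an illustrative name; the actual
change is in the `get_model` diff below):

    from typing import Optional

    import torch

    def resolve_dtype(dtype: Optional[str], quantize: Optional[str]) -> Optional[torch.dtype]:
        if dtype is None:
            if quantize in ["awq", "gptq"]:
                # These quantizers only work with float16 params.
                return torch.float16
            # Otherwise keep it as default and let every model
            # resolve its own default dtype.
            return None
        if dtype == "float16":
            return torch.float16
        if dtype == "bfloat16":
            return torch.bfloat16
        raise RuntimeError(f"Unknown dtype {dtype}")

    print(resolve_dtype(None, "gptq"))      # torch.float16
    print(resolve_dtype(None, None))        # None: the model picks its own default
    print(resolve_dtype("bfloat16", None))  # torch.bfloat16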

Author: Daniël de Kok
Date: 2024-05-24 19:01:07 +00:00
Parent: 9231098f3a
Commit: b9b5051abc
7 changed files with 642 additions and 47 deletions


@@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -9.640625,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 604,
"logprob": -2.4296875,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -2.4453125,
"special": false,
"text": " the"
},
{
"id": 2412,
"logprob": -2.8632812,
"special": false,
"text": " following"
},
{
"id": 235292,
"logprob": -2.1328125,
"special": false,
"text": ":"
},
{
"id": 109,
"logprob": -0.76660156,
"special": false,
"text": "\n\n"
},
{
"id": 235287,
"logprob": -1.3837891,
"special": false,
"text": "*"
},
{
"id": 235248,
"logprob": -1.9746094,
"special": false,
"text": " "
},
{
"id": 199,
"logprob": -1.4189453,
"special": false,
"text": "<strong>"
},
{
"id": 1232,
"logprob": -4.34375,
"special": false,
"text": "Name"
},
{
"id": 208,
"logprob": -0.8852539,
"special": false,
"text": "</strong>"
}
],
"top_tokens": null
},
"generated_text": " for the following:\n\n* <strong>Name</strong>"
}


@@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -9.65625,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.3671875,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 604,
"logprob": -0.36938477,
"special": false,
"text": " for"
},
{
"id": 235248,
"logprob": -1.8046875,
"special": false,
"text": " "
},
{
"id": 235274,
"logprob": -0.46240234,
"special": false,
"text": "1"
},
{
"id": 235284,
"logprob": -1.7460938,
"special": false,
"text": "2"
},
{
"id": 235265,
"logprob": -1.9443359,
"special": false,
"text": "."
},
{
"id": 235284,
"logprob": -1.4550781,
"special": false,
"text": "2"
},
{
"id": 235308,
"logprob": -1.0205078,
"special": false,
"text": "5"
},
{
"id": 235290,
"logprob": -1.0283203,
"special": false,
"text": "-"
},
{
"id": 235274,
"logprob": -1.2783203,
"special": false,
"text": "1"
},
{
"id": 235284,
"logprob": 0.0,
"special": false,
"text": "2"
}
],
"top_tokens": null
},
"generated_text": "Test request for 12.25-12"
}


@@ -0,0 +1,358 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -9.6484375,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.359375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 604,
"logprob": -2.4277344,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -2.4394531,
"special": false,
"text": " the"
},
{
"id": 2412,
"logprob": -2.8613281,
"special": false,
"text": " following"
},
{
"id": 235292,
"logprob": -2.1523438,
"special": false,
"text": ":"
},
{
"id": 109,
"logprob": -0.76220703,
"special": false,
"text": "\n\n"
},
{
"id": 235287,
"logprob": -1.3642578,
"special": false,
"text": "*"
},
{
"id": 235248,
"logprob": -2.0175781,
"special": false,
"text": " "
},
{
"id": 199,
"logprob": -1.4238281,
"special": false,
"text": "<strong>"
},
{
"id": 1232,
"logprob": -4.328125,
"special": false,
"text": "Name"
},
{
"id": 208,
"logprob": -0.8881836,
"special": false,
"text": "</strong>"
}
],
"top_tokens": null
},
"generated_text": " for the following:\n\n* <strong>Name</strong>"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -9.6484375,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 604,
"logprob": -2.4238281,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -2.4453125,
"special": false,
"text": " the"
},
{
"id": 2412,
"logprob": -2.859375,
"special": false,
"text": " following"
},
{
"id": 235292,
"logprob": -2.1445312,
"special": false,
"text": ":"
},
{
"id": 109,
"logprob": -0.7631836,
"special": false,
"text": "\n\n"
},
{
"id": 235287,
"logprob": -1.3642578,
"special": false,
"text": "*"
},
{
"id": 235248,
"logprob": -1.9960938,
"special": false,
"text": " "
},
{
"id": 199,
"logprob": -1.4179688,
"special": false,
"text": "<strong>"
},
{
"id": 1232,
"logprob": -4.3359375,
"special": false,
"text": "Name"
},
{
"id": 208,
"logprob": -0.8847656,
"special": false,
"text": "</strong>"
}
],
"top_tokens": null
},
"generated_text": " for the following:\n\n* <strong>Name</strong>"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -9.640625,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.3671875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 604,
"logprob": -2.4257812,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -2.4453125,
"special": false,
"text": " the"
},
{
"id": 2412,
"logprob": -2.8789062,
"special": false,
"text": " following"
},
{
"id": 235292,
"logprob": -2.1367188,
"special": false,
"text": ":"
},
{
"id": 109,
"logprob": -0.76171875,
"special": false,
"text": "\n\n"
},
{
"id": 235287,
"logprob": -1.3515625,
"special": false,
"text": "*"
},
{
"id": 235248,
"logprob": -1.9873047,
"special": false,
"text": " "
},
{
"id": 199,
"logprob": -1.4169922,
"special": false,
"text": "<strong>"
},
{
"id": 1232,
"logprob": -4.3320312,
"special": false,
"text": "Name"
},
{
"id": 208,
"logprob": -0.8930664,
"special": false,
"text": "</strong>"
}
],
"top_tokens": null
},
"generated_text": " for the following:\n\n* <strong>Name</strong>"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -9.6484375,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.359375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 604,
"logprob": -2.4179688,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -2.4492188,
"special": false,
"text": " the"
},
{
"id": 2412,
"logprob": -2.8574219,
"special": false,
"text": " following"
},
{
"id": 235292,
"logprob": -2.1445312,
"special": false,
"text": ":"
},
{
"id": 109,
"logprob": -0.7519531,
"special": false,
"text": "\n\n"
},
{
"id": 235287,
"logprob": -1.3623047,
"special": false,
"text": "*"
},
{
"id": 235248,
"logprob": -1.9707031,
"special": false,
"text": " "
},
{
"id": 199,
"logprob": -1.4267578,
"special": false,
"text": "<strong>"
},
{
"id": 1232,
"logprob": -4.3359375,
"special": false,
"text": "Name"
},
{
"id": 208,
"logprob": -0.88427734,
"special": false,
"text": "</strong>"
}
],
"top_tokens": null
},
"generated_text": " for the following:\n\n* <strong>Name</strong>"
}
]


@@ -0,0 +1,62 @@
import pytest


@pytest.fixture(scope="module")
def flash_gemma_gptq_handle(launcher):
    with launcher("TechxGenus/gemma-2b-GPTQ", num_shard=1, quantize="gptq") as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_gemma_gptq(flash_gemma_gptq_handle):
    await flash_gemma_gptq_handle.health(300)
    return flash_gemma_gptq_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq(flash_gemma_gptq, response_snapshot):
    response = await flash_gemma_gptq.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot):
    response = await flash_gemma_gptq.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq_load(
    flash_gemma_gptq, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_gemma_gptq, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot


@@ -263,9 +263,13 @@ def get_model(
     trust_remote_code: bool,
 ) -> Model:
     if dtype is None:
-        # Keep it as default for now and let
-        # every model resolve their own default dtype.
-        dtype = None
+        if quantize in ["awq", "gptq"]:
+            # These quantizers only work with float16 params.
+            dtype = torch.float16
+        else:
+            # Keep it as default for now and let
+            # every model resolve their own default dtype.
+            dtype = None
     elif dtype == "float16":
         dtype = torch.float16
     elif dtype == "bfloat16":


@@ -78,7 +78,7 @@ def _load_multi_mqa_gptq(
             quant_method,
         ) = weights._get_gptq_params()
         if quant_method == "gptq":
-            g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
+            g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx", to_dtype=False)
             g_idx = g_idx.to(device=weights.device)
         elif quant_method == "awq":
             g_idx = None


@@ -71,19 +71,19 @@ class Weights:
     def get_shape(self, tensor_name: str):
         return self._get_slice(tensor_name).get_shape()
 
-    def get_tensor(self, tensor_name: str, to_device=True):
+    def get_tensor(
+        self, tensor_name: str, to_device: bool = True, to_dtype: bool = True
+    ):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
-        # Special case for gptq which shouldn't convert
-        # u4 which are disguised as int32
-        if tensor.dtype not in [torch.int32, torch.int64]:
+        if to_dtype:
             tensor = tensor.to(dtype=self.dtype)
         if to_device:
             tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_partial_sharded(self, tensor_name: str, dim: int):
+    def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype: bool = True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
@@ -101,14 +101,12 @@ class Weights:
             tensor = slice_[:, start:stop]
         else:
             raise NotImplementedError("Let's make that generic when needed")
-        # Special case for gptq which shouldn't convert
-        # u4 which are disguised as int32
-        if tensor.dtype != torch.int32:
+        if to_dtype:
             tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_sharded(self, tensor_name: str, dim: int, to_dtype: bool = True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
@@ -117,7 +115,7 @@ class Weights:
         assert (
             size % world_size == 0
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-        return self.get_partial_sharded(tensor_name, dim)
+        return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype)
 
     def _get_qweight(self, name: str):
         slice_ = self._get_slice(name)
@@ -163,10 +161,9 @@ class Weights:
             qzeros = self._get_qweight(f"{prefix}.qzeros")
             scales = self._get_qweight(f"{prefix}.scales")
-            scales = scales.to(dtype=self.dtype)
 
             if quantize == "gptq" and quant_method == "gptq":
-                g_idx = self.get_tensor(f"{prefix}.g_idx")
+                g_idx = self.get_tensor(f"{prefix}.g_idx", to_dtype=False)
             elif quantize == "gptq" and quant_method == "awq":
                 log_once(
                     logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
@@ -211,7 +208,11 @@ class Weights:
         if quantize in ["gptq", "awq"]:
             try:
                 qweight = torch.cat(
-                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
+                    [
+                        self.get_sharded(f"{p}.qweight", dim=1, to_dtype=False)
+                        for p in prefixes
+                    ],
+                    dim=1,
                 )
             except RuntimeError:
                 raise RuntimeError(
@@ -219,10 +220,18 @@ class Weights:
                 )
 
             qzeros = torch.cat(
-                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
+                [
+                    self.get_sharded(f"{p}.qzeros", dim=1, to_dtype=False)
+                    for p in prefixes
+                ],
+                dim=1,
             )
             scales = torch.cat(
-                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
+                [
+                    self.get_sharded(f"{p}.scales", dim=1, to_dtype=False)
+                    for p in prefixes
+                ],
+                dim=1,
             )
 
             bits, groupsize, desc_act, quant_method = self._get_gptq_params()
@@ -234,7 +243,7 @@ class Weights:
                 )
 
             if quantize == "gptq" and quant_method == "gptq":
-                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
+                w = [self.get_tensor(f"{p}.g_idx", to_dtype=False) for p in prefixes]
                 for w2 in w[1:]:
                     torch.testing.assert_close(w2, w[0])
                 g_idx = w[0]
@@ -265,22 +274,6 @@ class Weights:
             weight = torch.cat(w, dim=dim)
         return weight
 
-    def get_tensor_shard(self, var, dim):
-        world_size = self.process_group.size()
-        rank = self.process_group.rank()
-        block_size = var.size()[dim] // world_size
-        start = rank * block_size
-        stop = (rank + 1) * block_size
-        if dim == 0:
-            tensor = var[start:stop]
-        elif dim == 1:
-            tensor = var[:, start:stop]
-        else:
-            raise NotImplementedError("Let's make that generic when needed")
-        tensor = tensor.to(dtype=self.dtype)
-        tensor = tensor.to(device=self.device)
-        return tensor
-
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
             use_exllama = True
@@ -294,14 +287,14 @@ class Weights:
                 use_exllama = False
 
             try:
-                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+                qweight = self.get_sharded(f"{prefix}.qweight", dim=0, to_dtype=False)
             except RuntimeError:
                 raise RuntimeError(
                     "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                 )
 
             if quant_method == "gptq":
-                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0, to_dtype=False)
             elif quant_method == "awq":
                 g_idx = None
@@ -335,11 +328,11 @@ class Weights:
                 log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
 
             if use_exllama and groupsize != -1:
-                qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
-                scales = self.get_sharded(f"{prefix}.scales", dim=0)
+                qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0, to_dtype=False)
+                scales = self.get_sharded(f"{prefix}.scales", dim=0, to_dtype=False)
             else:
-                qzeros = self.get_tensor(f"{prefix}.qzeros")
-                scales = self.get_tensor(f"{prefix}.scales")
+                qzeros = self.get_tensor(f"{prefix}.qzeros", to_dtype=False)
+                scales = self.get_tensor(f"{prefix}.scales", to_dtype=False)
 
             if use_exllama and g_idx is not None:
                 g_idx = g_idx - g_idx[0]
@@ -368,14 +361,14 @@ class Weights:
            bits, groupsize, _, _ = self._get_gptq_params()
 
            try:
-                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+                qweight = self.get_sharded(f"{prefix}.qweight", dim=0, to_dtype=False)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `awq` weight, make sure the model is already quantized"
                )
 
-            qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
-            scales = self.get_sharded(f"{prefix}.scales", dim=0)
+            qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0, to_dtype=False)
+            scales = self.get_sharded(f"{prefix}.scales", dim=0, to_dtype=False)
            g_idx = None
            use_exllama = False
@@ -386,8 +379,8 @@ class Weights:
     def _get_gptq_params(self) -> Tuple[int, int, int, str]:
         try:
-            bits = self.get_tensor("gptq_bits").item()
-            groupsize = self.get_tensor("gptq_groupsize").item()
+            bits = self.get_tensor("gptq_bits", to_dtype=False).item()
+            groupsize = self.get_tensor("gptq_groupsize", to_dtype=False).item()
             desc_act = False
             quant_method = "gptq"
         except (SafetensorError, RuntimeError) as e: