commit bcdb02e41a
parent fb59c56215
Author: IlyasMoutawwakil
Date:   2024-02-01 19:37:02 +00:00


@@ -154,17 +154,18 @@ class Weights:
                     f"Cannot load `{quantize}` weight, make sure the model is already quantized."
                 )
 
+            bits, groupsize, _, quant_method = self._get_gptq_params()
+
             qzeros = self._get_qweight(f"{prefix}.qzeros")
             scales = self._get_qweight(f"{prefix}.scales")
             scales = scales.to(dtype=self.dtype)
 
-            if quantize == "gptq" and self.quant_method == "gptq":
+            if quantize == "gptq" and quant_method == "gptq":
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
             else:
                 g_idx = None
 
-            bits, groupsize, _, _ = self._get_gptq_params()
-            if quantize == "gptq" and self.quant_method == "awq":
+            if quantize == "gptq" and quant_method == "awq":
                 log_once(
                     logger.info,
                     "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
@@ -212,7 +213,9 @@ class Weights:
                 [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
             )
 
-            if quantize == "gptq" and self.quant_method == "gptq":
+            bits, groupsize, desc_act, quant_method = self._get_gptq_params()
+
+            if quantize == "gptq" and quant_method == "gptq":
                 w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                 for w2 in w[1:]:
                     torch.testing.assert_close(w2, w[0])
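The same hoisting applies on the multi-prefix (column-parallel) path, and the pre-existing context lines show why the `g_idx` tensors are compared: every shard of a layer must see the identical act-order permutation before the quantized tensors can be concatenated. A standalone illustration of that consistency check (hypothetical tensors, not the class method itself):

    import torch

    # Two shards of the same layer must agree on the act-order permutation.
    g_idx_per_shard = [torch.tensor([0, 0, 1, 1]), torch.tensor([0, 0, 1, 1])]
    for other in g_idx_per_shard[1:]:
        torch.testing.assert_close(other, g_idx_per_shard[0])  # raises on mismatch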
@@ -220,7 +223,6 @@ class Weights:
             else:
                 g_idx = None
 
-            bits, groupsize, desc_act, quant_method = self._get_gptq_params()
             from text_generation_server.utils.layers import HAS_EXLLAMA
 
             use_exllama = (
@@ -347,7 +349,7 @@ class Weights:
         weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
 
-    def _get_gptq_params(self) -> Tuple[int, int, int]:
+    def _get_gptq_params(self) -> Tuple[int, int, int, str]:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
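This signature change anchors everything above: `_get_gptq_params` now reports which quantization scheme produced the checkpoint alongside the numeric parameters. The diff truncates before the new return statement, so the following is only a plausible sketch of the 4-tuple version; the fallback branch and its attribute names (`self.gptq_bits`, `self.quant_method`, ...) are assumptions, not taken from the commit:

    from typing import Tuple

    def _get_gptq_params(self) -> Tuple[int, int, int, str]:
        try:
            # Checkpoint-level tensors, as in the hunk above.
            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()
            desc_act = False        # assumption: not stored as a tensor
            quant_method = "gptq"   # assumption: tensor-based checkpoints are GPTQ
        except RuntimeError:
            # Assumed fallback to values parsed from the model's quantization
            # config; these attribute names are illustrative.
            bits = self.gptq_bits
            groupsize = self.gptq_groupsize
            desc_act = self.gptq_desc_act
            quant_method = self.quant_method

        return bits, groupsize, desc_act, quant_method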