commit bcdb02e41a
parent fb59c56215
Author: IlyasMoutawwakil
Date:   2024-02-01 19:37:02 +00:00


@@ -154,17 +154,18 @@ class Weights:
                 f"Cannot load `{quantize}` weight, make sure the model is already quantized."
             )
+        bits, groupsize, _, quant_method = self._get_gptq_params()
+
         qzeros = self._get_qweight(f"{prefix}.qzeros")
         scales = self._get_qweight(f"{prefix}.scales")
         scales = scales.to(dtype=self.dtype)
-        if quantize == "gptq" and self.quant_method == "gptq":
+        if quantize == "gptq" and quant_method == "gptq":
             g_idx = self.get_tensor(f"{prefix}.g_idx")
         else:
             g_idx = None
-        bits, groupsize, _, _ = self._get_gptq_params()
-        if quantize == "gptq" and self.quant_method == "awq":
+        if quantize == "gptq" and quant_method == "awq":
             log_once(
                 logger.info,
                 "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
@@ -212,7 +213,9 @@ class Weights:
             [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
         )
-        if quantize == "gptq" and self.quant_method == "gptq":
+        bits, groupsize, desc_act, quant_method = self._get_gptq_params()
+
+        if quantize == "gptq" and quant_method == "gptq":
             w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
             for w2 in w[1:]:
                 torch.testing.assert_close(w2, w[0])
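A self-contained illustration of the invariant the torch.testing.assert_close loop above enforces: when a column-sharded GPTQ weight is reassembled from several prefixes, each shard carries its own copy of g_idx (the per-column group assignment), and all copies must be identical, so each one is checked against the first. Toy tensors, not the repo's data:

import torch

# Three shards of the same layer; each carries its own copy of g_idx.
g_idx_copies = [torch.arange(8) // 4 for _ in range(3)]

# Mirror of the loop in the hunk: every copy must match the first.
for other in g_idx_copies[1:]:
    torch.testing.assert_close(other, g_idx_copies[0])

g_idx = g_idx_copies[0]
print(g_idx)  # tensor([0, 0, 0, 0, 1, 1, 1, 1])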
@@ -220,7 +223,6 @@ class Weights:
         else:
             g_idx = None
-        bits, groupsize, desc_act, quant_method = self._get_gptq_params()
 
         from text_generation_server.utils.layers import HAS_EXLLAMA
 
         use_exllama = (
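The excerpt cuts off inside the use_exllama = ( expression. As a hedged sketch only: the gate is of this general shape, but the exact conditions in the repo may differ, and can_use_exllama is a hypothetical name:

HAS_EXLLAMA = True  # stand-in for the flag imported from utils.layers


def can_use_exllama(bits: int, quantize: str, desc_act: bool) -> bool:
    # Exllama kernels target 4-bit GPTQ layouts; act-order (desc_act)
    # checkpoints generally need a fallback path instead.
    return HAS_EXLLAMA and bits == 4 and quantize == "gptq" and not desc_act


print(can_use_exllama(4, "gptq", False))  # True
print(can_use_exllama(4, "gptq", True))   # False: act-order checkpoint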
@@ -347,7 +349,7 @@ class Weights:
         weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
 
-    def _get_gptq_params(self) -> Tuple[int, int, int]:
+    def _get_gptq_params(self) -> Tuple[int, int, int, str]:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
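The final hunk is truncated before the fallback branch of _get_gptq_params. Below is a self-contained sketch of the overall shape such a helper plausibly has after this change, returning quant_method as the fourth element; the fallback source, exception handling, and field names are assumptions, not the repo's exact code:

from typing import Tuple


class GPTQParamsSketch:
    """Toy loader showing one plausible shape for the updated helper."""

    def __init__(self, tensors: dict, config: dict):
        self.tensors = tensors  # stand-in for checkpoint metadata tensors
        self.config = config    # stand-in for values parsed from config.json

    def get_tensor(self, name: str) -> int:
        return self.tensors[name]  # KeyError plays the role of a miss

    def _get_gptq_params(self) -> Tuple[int, int, bool, str]:
        try:
            # Preferred source: metadata tensors embedded in the checkpoint.
            bits = int(self.get_tensor("gptq_bits"))
            groupsize = int(self.get_tensor("gptq_groupsize"))
            desc_act = False
            quant_method = "gptq"
        except KeyError:
            # Fallback: values recorded earlier from the model's config.
            bits = self.config["bits"]
            groupsize = self.config["groupsize"]
            desc_act = self.config.get("desc_act", False)
            quant_method = self.config.get("quant_method", "gptq")
        return bits, groupsize, desc_act, quant_method


w = GPTQParamsSketch({}, {"bits": 4, "groupsize": 128, "quant_method": "awq"})
print(w._get_gptq_params())  # (4, 128, False, 'awq')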