From 646ab282855e39ab674619d13bf609791d123cb5 Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Thu, 1 Feb 2024 19:37:02 +0000
Subject: [PATCH] typing

---
 server/text_generation_server/utils/weights.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index aabd52f4..875ac464 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -154,17 +154,18 @@ class Weights:
                 f"Cannot load `{quantize}` weight, make sure the model is already quantized."
             )
 
+        bits, groupsize, _, quant_method = self._get_gptq_params()
+
         qzeros = self._get_qweight(f"{prefix}.qzeros")
         scales = self._get_qweight(f"{prefix}.scales")
         scales = scales.to(dtype=self.dtype)
-        if quantize == "gptq" and self.quant_method == "gptq":
+
+        if quantize == "gptq" and quant_method == "gptq":
             g_idx = self.get_tensor(f"{prefix}.g_idx")
         else:
             g_idx = None
 
-        bits, groupsize, _, _ = self._get_gptq_params()
-
-        if quantize == "gptq" and self.quant_method == "awq":
+        if quantize == "gptq" and quant_method == "awq":
             log_once(
                 logger.info,
                 "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
@@ -212,7 +213,9 @@ class Weights:
                 [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
             )
 
-            if quantize == "gptq" and self.quant_method == "gptq":
+            bits, groupsize, desc_act, quant_method = self._get_gptq_params()
+
+            if quantize == "gptq" and quant_method == "gptq":
                 w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                 for w2 in w[1:]:
                     torch.testing.assert_close(w2, w[0])
@@ -220,7 +223,6 @@ class Weights:
             else:
                 g_idx = None
 
-            bits, groupsize, desc_act, quant_method = self._get_gptq_params()
             from text_generation_server.utils.layers import HAS_EXLLAMA
 
             use_exllama = (
@@ -347,7 +349,7 @@ class Weights:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
             return weight
 
-    def _get_gptq_params(self) -> Tuple[int, int, int]:
+    def _get_gptq_params(self) -> Tuple[int, int, int, str]:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
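
For context, a minimal runnable sketch of the four-tuple contract that the
corrected annotation describes. The WeightsStub class and its hard-coded
values are illustrative assumptions, not code from text-generation-inference:

    from typing import Tuple

    # Hypothetical stand-in for the real Weights class, used only to
    # exercise the corrected annotation: _get_gptq_params advertises
    # (bits, groupsize, desc_act, quant_method). Values are assumptions.
    class WeightsStub:
        def _get_gptq_params(self) -> Tuple[int, int, int, str]:
            bits = 4               # e.g. 4-bit quantization
            groupsize = 128        # e.g. quantization group size
            desc_act = 0           # act-order flag, stored as an int
            quant_method = "awq"   # "gptq" or "awq"
            return bits, groupsize, desc_act, quant_method

    # Call sites unpack all four values and branch on quant_method,
    # mirroring the patched hunks above:
    bits, groupsize, desc_act, quant_method = WeightsStub()._get_gptq_params()
    if quant_method == "awq":
        print("AWQ checkpoint: repack weights for the Exllama/GPTQ kernels")
    else:
        print("GPTQ checkpoint: a g_idx tensor is expected")

Returning the quantization method alongside the numeric parameters is what
lets the patched call sites drop the stale self.quant_method attribute and
read all GPTQ metadata from a single _get_gptq_params() call, hoisted above
the branches that consume it.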