From bc157af9b0a73163ffd4998a0788c970a9334722 Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Thu, 8 Feb 2024 16:05:09 +0100
Subject: [PATCH] generate g_idx only for triton kernel

---
 .../text_generation_server/utils/weights.py  | 50 +++++++++++--------
 1 file changed, 28 insertions(+), 22 deletions(-)

diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 767a23b2..8f7e1f10 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -220,6 +220,12 @@ class Weights:
 
         bits, groupsize, desc_act, quant_method = self._get_gptq_params()
 
+            from text_generation_server.utils.layers import HAS_EXLLAMA
+
+            use_exllama = (
+                bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
+            )
+
             if quantize == "gptq" and quant_method == "gptq":
                 w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                 for w2 in w[1:]:
@@ -234,19 +240,18 @@ class Weights:
                 )
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
 
-                g_idx = (
-                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
-                    // groupsize
-                ).to(dtype=torch.int32)
+                if use_exllama:
+                    g_idx = None
+                else:
+                    g_idx = (
+                        torch.arange(
+                            qweight.shape[0] * (32 // bits), device=qweight.device
+                        )
+                        // groupsize
+                    ).to(dtype=torch.int32)
             else:
                 g_idx = None
 
-            from text_generation_server.utils.layers import HAS_EXLLAMA
-
-            use_exllama = (
-                bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
-            )
-
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -341,10 +346,15 @@ class Weights:
                 )
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
 
-                g_idx = (
-                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
-                    // groupsize
-                ).to(dtype=torch.int32)
+                if use_exllama:
+                    g_idx = None
+                else:
+                    g_idx = (
+                        torch.arange(
+                            qweight.shape[0] * (32 // bits), device=qweight.device
+                        )
+                        // groupsize
+                    ).to(dtype=torch.int32)
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
@@ -371,8 +381,8 @@ class Weights:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
-            quant_method = "gptq"
             desc_act = False
+            quant_method = "gptq"
         except (SafetensorError, RuntimeError) as e:
             try:
                 bits = self.gptq_bits
@@ -397,8 +407,8 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
             self.gptq_groupsize = data["quantization_config"]["group_size"]
-            self.quant_method = data["quantization_config"]["quant_method"]
             self.gptq_desc_act = data["quantization_config"]["desc_act"]
+            self.quant_method = data["quantization_config"]["quant_method"]
         except Exception:
             filename = "quantize_config.json"
             try:
@@ -412,11 +422,9 @@ class Weights:
                     data = json.load(f)
                 self.gptq_bits = data["bits"]
                 self.gptq_groupsize = data["group_size"]
+                self.gptq_desc_act = data["desc_act"]
                 if "version" in data and data["version"] == "GEMM":
                     self.quant_method = "awq"
-                else:
-                    self.quant_method = "gptq"
-                self.gptq_desc_act = data["desc_act"]
             except Exception:
                 filename = "quant_config.json"
                 try:
@@ -430,10 +438,8 @@ class Weights:
                         data = json.load(f)
                     self.gptq_bits = data["w_bit"]
                     self.gptq_groupsize = data["q_group_size"]
+                    self.gptq_desc_act = data["desc_act"]
                     if "version" in data and data["version"] == "GEMM":
                         self.quant_method = "awq"
-                    else:
-                        self.quant_method = "gptq"
-                    self.gptq_desc_act = data["desc_act"]
                 except Exception:
                     pass
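
Note on the g_idx change: the patch hoists the use_exllama decision above the branch and sets g_idx = None whenever use_exllama is true (bits == 4, HAS_EXLLAMA, quantize == "gptq", and desc_act off). With desc_act (act-order) disabled, rows map to quantization groups sequentially, so the exllama path needs no explicit mapping; only the triton kernel consumes a materialized g_idx. A minimal standalone sketch of the tensor the triton path builds, with hypothetical shapes and group size (not taken from the patch):

    import torch

    bits = 4          # 4-bit packing, as required by the use_exllama gate
    groupsize = 128   # hypothetical quantization group size
    packed_rows = 32  # hypothetical qweight.shape[0] (features packed into int32)

    # Each packed row holds 32 // bits input features; every block of
    # `groupsize` features shares one set of scales and zero points.
    in_features = packed_rows * (32 // bits)
    g_idx = (torch.arange(in_features) // groupsize).to(dtype=torch.int32)

    # Features 0..127 map to group 0, features 128..255 to group 1, etc.
    assert g_idx[0].item() == 0 and g_idx[groupsize].item() == 1

Since the mapping is this trivial arange-based ramp when desc_act is off, skipping its allocation on the exllama path is pure savings, which is what the commit subject describes.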
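
Note on the config-detection hunks: the patch now stores desc_act before settling quant_method in each of the three fallback files (config.json with its quantization_config section, then quantize_config.json, then quant_config.json with the w_bit/q_group_size keys). A compressed sketch of that fallback chain under the same key names; read_quant_params is a hypothetical helper written for illustration, not an API from the repo:

    import json
    import os

    def read_quant_params(model_dir: str) -> dict:
        # Probe the same files, in the same order, as the patched code.
        probes = [
            ("config.json", "bits", "group_size"),
            ("quantize_config.json", "bits", "group_size"),
            ("quant_config.json", "w_bit", "q_group_size"),
        ]
        for filename, bits_key, group_key in probes:
            try:
                with open(os.path.join(model_dir, filename)) as f:
                    data = json.load(f)
                if filename == "config.json":
                    data = data["quantization_config"]
                params = {
                    "bits": data[bits_key],
                    "groupsize": data[group_key],
                    # desc_act is read before the method is decided, as in the patch.
                    "desc_act": data["desc_act"],
                }
                if data.get("version") == "GEMM":
                    # AWQ checkpoints mark themselves with version == "GEMM".
                    params["quant_method"] = "awq"
                else:
                    # The patch no longer forces "gptq" here; this default is
                    # the sketch's own simplification.
                    params["quant_method"] = data.get("quant_method", "gptq")
                return params
            except Exception:
                continue
        raise RuntimeError("no quantization config found")

One reviewable side effect of the reordering: in the quantize_config.json and quant_config.json branches, self.quant_method is now left unset when version != "GEMM", where it previously defaulted to "gptq".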