diff --git a/server/text_generation_server/utils/pack_utils.py b/server/text_generation_server/utils/awq/conversion_utils.py
similarity index 94%
rename from server/text_generation_server/utils/pack_utils.py
rename to server/text_generation_server/utils/awq/conversion_utils.py
index d144b3cd..b19eafbb 100644
--- a/server/text_generation_server/utils/pack_utils.py
+++ b/server/text_generation_server/utils/awq/conversion_utils.py
@@ -15,10 +15,9 @@ def pack(imatrix: torch.Tensor, direction: str = "column"):
     Returns:
         qmatrix (torch.Tensor): packed matrix of integers
     """
-    imatrix = imatrix.to(torch.int8)
-    imatrix = torch.bitwise_and(imatrix, 0x0F)  # eventually correct overflow
-
-    shifts = torch.arange(0, 32, 4, device=imatrix.device)
+    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)
+
+    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually correct overflow
 
     if direction == "column":
         imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py
index 34895c01..8ad0dd80 100644
--- a/server/text_generation_server/utils/gptq/quant_linear.py
+++ b/server/text_generation_server/utils/gptq/quant_linear.py
@@ -182,7 +182,7 @@ try:
             )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
 
             zeros = (zeros >> zeros_shifter[None, :]) & maxq
-            zeros = (zeros + 1) & maxq  # add 1 and avoid overflow
+            zeros = (zeros + 1) & maxq  # eventually avoid overflow
 
             a = tl.load(a_ptrs, mask=a_mask, other=0.0)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
             b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
@@ -251,17 +251,7 @@ class QuantLinear(nn.Module):
         self.register_buffer("qweight", qweight)
         self.register_buffer("qzeros", qzeros)
         self.register_buffer("scales", scales)
-        if g_idx is not None:
-            self.register_buffer("g_idx", g_idx)
-        else:
-            self.register_buffer(
-                "g_idx",
-                torch.tensor(
-                    [i // groupsize for i in range(qweight.shape[0] * 32 // bits)],
-                    device=qweight.device,
-                    dtype=torch.int32,
-                ),
-            )
+        self.register_buffer("g_idx", g_idx)
         if bias is not None:
             self.register_buffer("bias", bias)
         else:
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 875ac464..759ea602 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -7,7 +7,6 @@ from loguru import logger
 from huggingface_hub import hf_hub_download
 import json
 from text_generation_server.utils.log import log_once
-from text_generation_server.utils.pack_utils import fast_awq_to_gptq
 
 
 class Weights:
@@ -162,15 +161,22 @@ class Weights:
 
             if quantize == "gptq" and quant_method == "gptq":
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
-            else:
-                g_idx = None
-
-            if quantize == "gptq" and quant_method == "awq":
+            elif quantize == "gptq" and quant_method == "awq":
                 log_once(
                     logger.info,
-                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
+                    "Converting AWQ weights to Exllama/GPTQ packing format, "
+                    "in order to be used with Exllama/GPTQ kernels.",
                 )
+                from text_generation_server.utils.awq.conversion_utils import (
+                    fast_awq_to_gptq,
+                )
+
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+                g_idx = torch.zeros(
+                    (qweight.shape[0] * 32 // bits),
+                    dtype=torch.int32,
+                    device=qweight.device,
+                )
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
@@ -220,8 +226,22 @@ class Weights:
                 for w2 in w[1:]:
                     torch.testing.assert_close(w2, w[0])
                 g_idx = w[0]
-            else:
-                g_idx = None
+            elif quantize == "gptq" and quant_method == "awq":
+                log_once(
+                    logger.info,
+                    "Converting AWQ weights to Exllama/GPTQ packing format, "
+                    "in order to be used with Exllama/GPTQ kernels.",
+                )
+                from text_generation_server.utils.awq.conversion_utils import (
+                    fast_awq_to_gptq,
+                )
+
+                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+                g_idx = torch.zeros(
+                    (qweight.shape[0] * 32 // bits),
+                    dtype=torch.int32,
+                    device=qweight.device,
+                )
 
             from text_generation_server.utils.layers import HAS_EXLLAMA
 
@@ -229,13 +249,6 @@ class Weights:
                 bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
             )
 
-            if quantize == "gptq" and quant_method == "awq":
-                log_once(
-                    logger.info,
-                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
-                )
-                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
-
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -279,7 +292,7 @@ class Weights:
 
             if quant_method == "gptq":
                 g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
-            else:
+            elif quant_method == "awq":
                 g_idx = None
 
             if self.process_group.size() > 1:
@@ -324,9 +337,19 @@ class Weights:
             if quant_method == "awq":
                 log_once(
                     logger.info,
-                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
+                    "Converting AWQ weights to Exllama/GPTQ packing format, "
+                    "in order to be used with Exllama/GPTQ kernels.",
                 )
+                from text_generation_server.utils.awq.conversion_utils import (
+                    fast_awq_to_gptq,
+                )
+
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+                g_idx = torch.zeros(
+                    (qweight.shape[0] * 32 // bits),
+                    dtype=torch.int32,
+                    device=qweight.device,
+                )
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
@@ -353,13 +376,14 @@ class Weights:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
+            quant_method = "gptq"
             desc_act = False
         except (SafetensorError, RuntimeError) as e:
             try:
                 bits = self.gptq_bits
                 groupsize = self.gptq_groupsize
-                quant_method = self.quant_method
                 desc_act = getattr(self, "gptq_desc_act", False)
+                quant_method = getattr(self, "quant_method", "gptq")
             except Exception:
                 raise e
 
@@ -378,8 +402,8 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
             self.gptq_groupsize = data["quantization_config"]["group_size"]
-            self.gptq_desc_act = data["quantization_config"].get("desc_act", False)
             self.quant_method = data["quantization_config"]["quant_method"]
+            self.gptq_desc_act = data["quantization_config"]["desc_act"]
         except Exception:
             filename = "quantize_config.json"
             try:
@@ -393,11 +417,11 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["bits"]
             self.gptq_groupsize = data["group_size"]
-            self.gptq_desc_act = data.get("desc_act", False)
             if "version" in data and data["version"] == "GEMM":
                 self.quant_method = "awq"
             else:
                 self.quant_method = "gptq"
+            self.gptq_desc_act = data["desc_act"]
         except Exception:
             filename = "quant_config.json"
             try:
@@ -411,10 +435,10 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["w_bit"]
             self.gptq_groupsize = data["q_group_size"]
-            self.gptq_desc_act = data.get("desc_act", False)
             if "version" in data and data["version"] == "GEMM":
                 self.quant_method = "awq"
             else:
                 self.quant_method = "gptq"
+            self.gptq_desc_act = data["desc_act"]
         except Exception:
             pass
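
Note for reviewers: below is a minimal standalone sketch (not part of the patch) of the column-wise 4-bit packing that `pack` in the renamed `awq/conversion_utils.py` performs: mask each value to its low nibble, group eight adjacent columns, and combine them into a single int32 with shifts of 0, 4, ..., 28 bits. The helper name `pack_columns_4bit` and the round-trip check are illustrative only.

import torch


def pack_columns_4bit(imatrix: torch.Tensor) -> torch.Tensor:
    # Shift amounts 0, 4, ..., 28: eight 4-bit fields per 32-bit word.
    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)
    # Keep only the low nibble of each value (corrects any overflow/negatives).
    imatrix = imatrix.to(torch.int32) & 0x0F
    # Group every 8 adjacent columns and shift each into its own 4-bit field.
    imatrix = imatrix.view(-1, imatrix.shape[1] // 8, 8)
    # The fields do not overlap, so summing them is equivalent to a bitwise OR.
    qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
    return qmatrix.to(torch.int32)


# Round trip: eight 4-bit columns pack into one int32 column and back.
vals = torch.randint(0, 16, (4, 8), dtype=torch.int32)
packed = pack_columns_4bit(vals)
assert packed.shape == (4, 1)
unpacked = (packed >> torch.arange(0, 32, 4, dtype=torch.int32)) & 0x0F
assert torch.equal(unpacked, vals)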