From 8074c40473ad6fb67ece92815cad0d4525cf4c3c Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Thu, 1 Feb 2024 18:35:41 +0000
Subject: [PATCH] adapt awq weights to exllama/gptq kernels

---
 .../utils/pack_utils.py                      | 98 +++++++++++++++++++
 .../text_generation_server/utils/weights.py  | 79 ++++++++++-----
 2 files changed, 155 insertions(+), 22 deletions(-)
 create mode 100644 server/text_generation_server/utils/pack_utils.py

diff --git a/server/text_generation_server/utils/pack_utils.py b/server/text_generation_server/utils/pack_utils.py
new file mode 100644
index 00000000..d144b3cd
--- /dev/null
+++ b/server/text_generation_server/utils/pack_utils.py
@@ -0,0 +1,98 @@
+import torch
+from typing import List
+
+
+AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
+REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
+
+
+def pack(imatrix: torch.Tensor, direction: str = "column"):
+    """
+    Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
+    Args:
+        imatrix (torch.Tensor): matrix of integers
+        direction (str): direction of packing, either "column" or "row"
+    Returns:
+        qmatrix (torch.Tensor): packed matrix of integers
+    """
+    imatrix = imatrix.to(torch.int8)
+    imatrix = torch.bitwise_and(imatrix, 0x0F)  # keep only the low 4 bits, correcting any overflow
+
+    shifts = torch.arange(0, 32, 4, device=imatrix.device)
+
+    if direction == "column":
+        imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
+        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
+
+    elif direction == "row":
+        imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1)
+        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)
+
+    qmatrix = qmatrix.to(torch.int32)
+
+    return qmatrix
+
+
+def unpack(qmatrix: torch.Tensor, direction: str = "column"):
+    """
+    Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
+    Args:
+        qmatrix (torch.Tensor): matrix of packed integers
+        direction (str): direction of unpacking, either "column" or "row"
+    Returns:
+        imatrix (torch.Tensor): matrix of integers
+    """
+    shifts = torch.arange(0, 32, 4, device=qmatrix.device)
+
+    if direction == "column":
+        imatrix = torch.bitwise_right_shift(
+            qmatrix[:, :, None], shifts[None, None, :]
+        ).view(qmatrix.shape[0], -1)
+
+    elif direction == "row":
+        imatrix = torch.bitwise_right_shift(
+            qmatrix[:, None, :], shifts[None, :, None]
+        ).view(-1, qmatrix.shape[-1])
+
+    imatrix = imatrix.to(torch.int8) & 0x0F  # keep only the low 4 bits, discarding sign-extension bits
+
+    return imatrix
+
+
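+# Illustrative round-trip example (not part of the runtime path): eight
+# 4-bit values occupy one int32, so a (rows, 8 * k) matrix packs to
+# (rows, k) along columns, and pack/unpack invert each other exactly:
+#     i = torch.randint(0, 16, (2, 8))
+#     q = pack(i, direction="column")  # shape (2, 1), dtype torch.int32
+#     assert torch.equal(unpack(q, direction="column"), i.to(torch.int8))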
+def apply_order(
+    imatrix: torch.Tensor,
+    direction: str = "column",
+    order: List[int] = AWQ_PACK_ORDER,
+):
+    """
+    Applies the given nibble order to a 4-bit integer matrix.
+    Args:
+        imatrix (torch.Tensor): matrix of integers
+        direction (str): direction in which to apply the order, either "column" or "row"
+        order (List[int]): order to apply, default is AWQ_PACK_ORDER
+    Returns:
+        imatrix (torch.Tensor): matrix of integers
+    """
+    if direction == "column":
+        imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape)
+    elif direction == "row":
+        imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape)
+
+    return imatrix
+
+
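+# Layout note (shapes are assumptions given for illustration): AWQ stores
+# qweight column-packed as (in_features, out_features // 8) int32, while
+# the exllama/GPTQ kernels expect row-packed weights of shape
+# (in_features // 8, out_features); zeros stay column-packed. GPTQ kernels
+# also add 1 to the stored zero points, hence the subtraction below.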
+def fast_awq_to_gptq(qweight, qzeros):
+    # awq uses column packing for both weights and zeros
+    izeros = unpack(qzeros, direction="column")
+    iweights = unpack(qweight, direction="column")
+
+    # Reverse the order of the iweights and izeros tensors
+    izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
+    iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
+    # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)
+    izeros = izeros - 1
+    # exllama uses row packing for weights and column packing for zeros
+    qzeros = pack(izeros, direction="column")
+    qweight = pack(iweights, direction="row")
+
+    return qweight, qzeros

diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 186733f3..aabd52f4 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -7,6 +7,7 @@ from loguru import logger
 from huggingface_hub import hf_hub_download
 import json
 from text_generation_server.utils.log import log_once
+from text_generation_server.utils.pack_utils import fast_awq_to_gptq


 class Weights:
@@ -46,7 +47,6 @@ class Weights:
         return self._handles[filename]

     def get_filename(self, tensor_name: str) -> (str, str):
-
         names = [tensor_name]
         if self.prefix is not None:
             prefixed = f"{self.prefix}.{tensor_name}"
@@ -157,12 +157,20 @@
             qzeros = self._get_qweight(f"{prefix}.qzeros")
             scales = self._get_qweight(f"{prefix}.scales")
             scales = scales.to(dtype=self.dtype)
-            if quantize == "gptq":
+            if quantize == "gptq" and self.quant_method == "gptq":
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
             else:
                 g_idx = None

-            bits, groupsize, _ = self._get_gptq_params()
+            bits, groupsize, _, _ = self._get_gptq_params()
+
+            if quantize == "gptq" and self.quant_method == "awq":
+                log_once(
+                    logger.info,
+                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
+                )
+                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
@@ -204,7 +212,7 @@
                 [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
             )

-            if quantize == "gptq":
+            if quantize == "gptq" and self.quant_method == "gptq":
                 w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                 for w2 in w[1:]:
                     torch.testing.assert_close(w2, w[0])
@@ -212,12 +220,20 @@
             else:
                 g_idx = None

-            bits, groupsize, desc_act = self._get_gptq_params()
+            bits, groupsize, desc_act, quant_method = self._get_gptq_params()
             from text_generation_server.utils.layers import HAS_EXLLAMA

             use_exllama = (
                 bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
             )
+
+            if quantize == "gptq" and quant_method == "awq":
+                log_once(
+                    logger.info,
+                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
+                )
+                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -243,7 +259,7 @@
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
             use_exllama = True
-            bits, groupsize, desc_act = self._get_gptq_params()
+            bits, groupsize, desc_act, quant_method = self._get_gptq_params()

             if bits != 4:
                 use_exllama = False
@@ -252,8 +268,19 @@
                 log_once(logger.warning, "Disabling exllama because desc_act=True")
                 use_exllama = False

+            try:
+                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )
+
+            if quant_method == "gptq":
+                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+            else:
+                g_idx = None
+
             if self.process_group.size() > 1:
-                g_idx = self.get_tensor(f"{prefix}.g_idx")
                 if g_idx is not None:
                     if (
                         not torch.equal(
                             g_idx.cpu(),
                             torch.tensor(
                                 [i // groupsize for i in range(g_idx.shape[0])],
                                 dtype=torch.int32,
                             ),
                         )
                         and not (g_idx == 0).all()
                     ):
                         # Exllama implementation does not support row tensor parallelism with act-order, as
                         # it would require to reorder input activations that are split unto several GPUs
                         use_exllama = False

-            try:
-                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
-            except RuntimeError:
-                raise RuntimeError(
-                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
-                )
-
             from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA

             if use_exllama:
                 if not HAS_EXLLAMA:
                     if CAN_EXLLAMA:
                         log_once(
                             logger.warning,
                             "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
                         )
                     use_exllama = False
@@ -289,8 +309,6 @@
                 else:
                     log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")

-            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
-
             if use_exllama and groupsize != -1:
                 qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                 scales = self.get_sharded(f"{prefix}.scales", dim=0)
@@ -298,12 +316,19 @@
             else:
                 qzeros = self.get_tensor(f"{prefix}.qzeros")
                 scales = self.get_tensor(f"{prefix}.scales")

-            if use_exllama:
+            if use_exllama and g_idx is not None:
                 g_idx = g_idx - g_idx[0]

+            if quant_method == "awq":
+                log_once(
+                    logger.info,
+                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
+                )
+                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
-            bits, groupsize, _ = self._get_gptq_params()
+            bits, groupsize, _, _ = self._get_gptq_params()

             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
@@ -331,11 +356,12 @@
             try:
                 bits = self.gptq_bits
                 groupsize = self.gptq_groupsize
+                quant_method = self.quant_method
                 desc_act = getattr(self, "gptq_desc_act", False)
             except Exception:
                 raise e

-        return bits, groupsize, desc_act
+        return bits, groupsize, desc_act, quant_method

     def _set_gptq_params(self, model_id, revision):
         filename = "config.json"
@@ -350,7 +376,8 @@
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
             self.gptq_groupsize = data["quantization_config"]["group_size"]
-            self.gptq_desc_act = data["quantization_config"]["desc_act"]
+            self.gptq_desc_act = data["quantization_config"].get("desc_act", False)
+            self.quant_method = data["quantization_config"]["quant_method"]
         except Exception:
             filename = "quantize_config.json"
             try:
@@ -364,7 +391,11 @@
                     data = json.load(f)
                 self.gptq_bits = data["bits"]
                 self.gptq_groupsize = data["group_size"]
-                self.gptq_desc_act = data["desc_act"]
+                self.gptq_desc_act = data.get("desc_act", False)
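+                # AWQ's quantize_config.json typically has no quant_method
+                # field; it stores the packing "version" (e.g. GEMM) instead,
+                # so treat GEMM-packed checkpoints as awq here.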
+                if "version" in data and data["version"] == "GEMM":
+                    self.quant_method = "awq"
+                else:
+                    self.quant_method = "gptq"
             except Exception:
                 filename = "quant_config.json"
                 try:
@@ -378,6 +409,10 @@
                         data = json.load(f)
                     self.gptq_bits = data["w_bit"]
                     self.gptq_groupsize = data["q_group_size"]
-                    self.gptq_desc_act = data["desc_act"]
+                    self.gptq_desc_act = data.get("desc_act", False)
+                    if "version" in data and data["version"] == "GEMM":
+                        self.quant_method = "awq"
+                    else:
+                        self.quant_method = "gptq"
                 except Exception:
                     pass
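
A quick offline sanity check of the conversion might look like this (a minimal sketch: tensor shapes and the group size are made up, and only the helpers added above are used):

    import torch
    from text_generation_server.utils.pack_utils import pack, fast_awq_to_gptq

    in_features, out_features, group_size = 64, 64, 32

    # Dummy AWQ-style tensors: column-pack random 4-bit integers
    # (eight nibbles per int32 along the last dimension).
    iweights = torch.randint(0, 16, (in_features, out_features))
    izeros = torch.randint(1, 16, (in_features // group_size, out_features))
    qweight = pack(iweights, direction="column")
    qzeros = pack(izeros, direction="column")

    gptq_qweight, gptq_qzeros = fast_awq_to_gptq(qweight, qzeros)

    # exllama/gptq kernels expect row-packed weights and column-packed zeros.
    assert gptq_qweight.shape == (in_features // 8, out_features)
    assert gptq_qzeros.shape == (in_features // group_size, out_features // 8)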