diff --git a/server/text_generation_server/utils/awq/pack_utils.py b/server/text_generation_server/utils/awq/pack_utils.py
index 9b15e1db..d144b3cd 100644
--- a/server/text_generation_server/utils/awq/pack_utils.py
+++ b/server/text_generation_server/utils/awq/pack_utils.py
@@ -15,10 +15,10 @@ def pack(imatrix: torch.Tensor, direction: str = "column"):
     Returns:
         qmatrix (torch.Tensor): packed matrix of integers
     """
-    shifts = torch.arange(0, 32, 4, device=imatrix.device)
-
     imatrix = imatrix.to(torch.int8)
     imatrix = torch.bitwise_and(imatrix, 0x0F)  # eventually correct overflow
+
+    shifts = torch.arange(0, 32, 4, device=imatrix.device)
 
     if direction == "column":
         imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
@@ -59,54 +59,6 @@ def unpack(qmatrix: torch.Tensor, direction: str = "column"):
     return imatrix
 
 
-def quantize(fmatrix, scales, zeros, group_size):
-    """
-    Quantizes a matrix of 16-bit floats into a matrix of 4-bit integers.
-    Args:
-        fmatrix (torch.Tensor): matrix of 16-bit floats
-        scales (torch.Tensor): matrix of 16-bit floats
-        zeros (torch.Tensor): matrix of 4-bit integers
-        group_size (int): group size
-    Returns:
-        imatrix (torch.Tensor): matrix of 4-bit integers
-    """
-    zeros = zeros.to(torch.int8) & 0x0F
-
-    imatrix = torch.round(
-        (
-            fmatrix / scales.repeat_interleave(group_size, dim=0)
-            + zeros.repeat_interleave(group_size, dim=0)
-        )
-    )
-
-    imatrix = imatrix.to(torch.int8) & 0x0F
-
-    return imatrix
-
-
-def dequantize(imatrix, scales, zeros, group_size):
-    """
-    Dequantizes a 4-bit integer matrix into a float matrix.
-    Args:
-        imatrix (torch.Tensor): matrix of 4-bit integers
-        scales (torch.Tensor): matrix of 16-bit floats
-        zeros (torch.Tensor): matrix of 4-bit integers
-        group_size (int): group size
-    Returns:
-        fmatrix (torch.Tensor): matrix of 16-bit floats
-    """
-    zeros = zeros.to(torch.int8) & 0x0F
-    imatrix = imatrix.to(torch.int8) & 0x0F
-
-    fmatrix = (
-        imatrix - zeros.repeat_interleave(group_size, dim=0)
-    ) * scales.repeat_interleave(group_size, dim=0)
-
-    fmatrix = fmatrix.to(torch.float16)
-
-    return fmatrix
-
-
 def apply_order(
     imatrix: torch.Tensor,
     direction: str = "column",
@@ -129,7 +81,7 @@ def apply_order(
     return imatrix
 
 
-def fast_awq_to_exllama(qweight, qzeros):
+def fast_awq_to_gptq(qweight, qzeros):
     # awq uses column packing for both weights and zeros
     izeros = unpack(qzeros, direction="column")
     iweights = unpack(qweight, direction="column")
@@ -137,7 +89,7 @@
     # Reverse the order of the iweight and izeros tensors
     izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
     iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
-    # Subtract 1 from the izeros tensor (exllama adds 1 during inference)
+    # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)
     izeros = izeros - 1
     # exllama uses row packing for weights and column packing for zeros
     qzeros = pack(izeros, direction="column")
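For intuition, here is a minimal, self-contained sketch (not part of the patch) of the column nibble packing that `pack`/`unpack` implement, and of the pack-order inversion that `fast_awq_to_gptq` relies on. The `AWQ_PACK_ORDER` value is an assumption matching the constant defined at the top of pack_utils.py, which these hunks do not show:

import torch

# Assumed constants from pack_utils.py: AWQ interleaves the eight nibbles
# of each int32; REVERSE_AWQ_PACK_ORDER is the inverse permutation.
AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]

order = torch.tensor(AWQ_PACK_ORDER)
reverse = torch.tensor(REVERSE_AWQ_PACK_ORDER)
assert torch.equal(order[reverse], torch.arange(8))  # inverse permutation


def pack_column(imatrix: torch.Tensor) -> torch.Tensor:
    # Pack eight 4-bit values into one int32 along dim 1, low nibble first.
    shifts = torch.arange(0, 32, 4, device=imatrix.device)
    nibbles = (imatrix.to(torch.int32) & 0x0F).view(imatrix.shape[0], -1, 8)
    return (nibbles << shifts).sum(dim=-1, dtype=torch.int64).to(torch.int32)


def unpack_column(qmatrix: torch.Tensor) -> torch.Tensor:
    # Inverse of pack_column: recover the eight nibbles from each int32.
    shifts = torch.arange(0, 32, 4, device=qmatrix.device)
    nibbles = (qmatrix[:, :, None] >> shifts) & 0x0F
    return nibbles.view(qmatrix.shape[0], -1).to(torch.int32)


x = torch.randint(0, 16, (4, 16), dtype=torch.int32)
assert torch.equal(unpack_column(pack_column(x)), x)  # lossless round trip

The conversion itself is then just: unpack both layouts, undo AWQ's nibble interleaving, subtract 1 from the zeros (the GPTQ kernels add it back at inference), and repack in GPTQ's layout.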
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 782744ed..7f20081d 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -25,7 +25,8 @@ HAS_AWQ = True
 try:
     from text_generation_server.utils.awq.quantize.qmodule import WQLinear
 except ImportError:
-    from text_generation_server.utils.awq.pack_utils import fast_awq_to_exllama
+    from text_generation_server.utils.awq.pack_utils import fast_awq_to_gptq
+
     HAS_AWQ = False
 
 try:
@@ -360,10 +361,13 @@ def get_linear(weight, bias, quantize):
                 bias=bias is not None,
             )
         elif HAS_EXLLAMA:
-            qweight, qzeros = fast_awq_to_exllama(qweight, qzeros)
+            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
             linear = ExllamaQuantLinear(
                 qweight, qzeros, scales, None, bias, bits, groupsize
             )
+        else:
+            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+            linear = QuantLinear(qweight, qzeros, scales, None, bias, bits, groupsize)
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
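The second hunk is what removes the hard requirement on exllama for AWQ checkpoints: if the native AWQ kernel is unavailable, the weights are repacked once at load time with `fast_awq_to_gptq` and served either by `ExllamaQuantLinear` or, new in this patch, by the default GPTQ `QuantLinear`. A small illustrative sketch of the resulting dispatch order, returning labels rather than constructing the real layers so it stays self-contained:

def awq_backend(has_awq: bool, has_exllama: bool) -> str:
    # Mirrors the if/elif/else in get_linear above.
    if has_awq:
        return "WQLinear"            # native AWQ kernels, no repacking needed
    if has_exllama:
        return "ExllamaQuantLinear"  # fast_awq_to_gptq, then exllama GPTQ kernels
    return "QuantLinear"             # new fallback: fast_awq_to_gptq, default GPTQ kernels


assert awq_backend(has_awq=False, has_exllama=False) == "QuantLinear"

Since the repacking runs once when the layer is built, the steady-state cost of the fallback path is just that of the GPTQ kernel it dispatches to.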