From acb7e1d465af0ace72410706264fcf1326879f48 Mon Sep 17 00:00:00 2001
From: Abhinav Kulkarni
Date: Wed, 13 Sep 2023 11:12:32 +0000
Subject: [PATCH] Added quantize_config.json support for AWQ

---
 .../utils/awq/quantize/qmodule.py             | 101 ++++++++++++++++++
 1 file changed, 101 insertions(+)
 create mode 100644 server/text_generation_server/utils/awq/quantize/qmodule.py

diff --git a/server/text_generation_server/utils/awq/quantize/qmodule.py b/server/text_generation_server/utils/awq/quantize/qmodule.py
new file mode 100644
index 00000000..e157ce55
--- /dev/null
+++ b/server/text_generation_server/utils/awq/quantize/qmodule.py
@@ -0,0 +1,101 @@
+# Copied from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
+
+import math
+import torch
+import torch.nn as nn
+import awq_inference_engine  # with CUDA kernels
+
+
+class ScaledActivation(nn.Module):
+    def __init__(self, module, scales):
+        super().__init__()
+        self.act = module
+        self.scales = nn.Parameter(scales.data)
+
+    def forward(self, x):
+        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
+
+
+class WQLinear(nn.Module):
+    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
+        super().__init__()
+
+        if w_bit not in [4]:
+            raise NotImplementedError("Only 4-bit is supported for now.")
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.w_bit = w_bit
+        self.group_size = group_size if group_size != -1 else in_features
+        # quick sanity check (make sure the sizes are aligned)
+        assert self.in_features % self.group_size == 0
+        assert out_features % (32 // self.w_bit) == 0
+
+        self.register_buffer('qweight', torch.zeros((in_features, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
+        self.register_buffer('qzeros', torch.zeros((in_features // self.group_size, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
+        self.register_buffer('scales', torch.zeros((in_features // self.group_size, out_features), dtype=torch.float16, device=dev))
+        if bias:
+            self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16, device=dev))
+        else:
+            self.bias = None
+
+    @classmethod
+    def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None):
+        awq_linear = cls(w_bit, group_size, linear.in_features, linear.out_features, linear.bias is not None, linear.weight.device)
+        if init_only:  # just prepare empty buffers for loading a state dict
+            return awq_linear
+
+        # need scales and zeros info for real quantization
+        assert scales is not None and zeros is not None
+        scale_zeros = zeros * scales
+
+        awq_linear.scales = scales.clone().half()
+        if linear.bias is not None:
+            awq_linear.bias = linear.bias.clone().half()
+
+        pack_num = 32 // awq_linear.w_bit  # values packed per int32 (8 for 4-bit)
+
+        intweight = []
+        for idx in range(awq_linear.in_features):
+            intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[idx // group_size]) / awq_linear.scales[idx // group_size]).to(torch.int)[:, None])
+        intweight = torch.cat(intweight, dim=1)
+        intweight = intweight.t().contiguous()
+        intweight = intweight.to(dtype=torch.int32)
+        qweight = torch.zeros((intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=intweight.device)
+
+        for col in range(intweight.shape[1] // pack_num):
+            if awq_linear.w_bit == 4:
+                order_map = [0, 2, 4, 6, 1, 3, 5, 7]  # interleaved nibble order expected by the kernels
+            else:
+                raise NotImplementedError("Only 4-bit is supported for now.")
+            for i in range(pack_num):
+                qweight_col = intweight[:, col * pack_num + order_map[i]]
+                qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
+        awq_linear.qweight = qweight
+
+        zeros = zeros.to(dtype=torch.int32)
+        qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=zeros.device)
+
+        for col in range(zeros.shape[1] // pack_num):
+            if awq_linear.w_bit == 4:
+                order_map = [0, 2, 4, 6, 1, 3, 5, 7]  # same interleaved order as for qweight
+            else:
+                raise NotImplementedError("Only 4-bit is supported for now.")
+            for i in range(pack_num):
+                qzero_col = zeros[:, col * pack_num + order_map[i]]
+                qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
+        awq_linear.qzeros = qzeros
+
+        return awq_linear
+
+    @torch.no_grad()
+    def forward(self, x):
+        out_shape = x.shape[:-1] + (self.out_features, )
+        out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
+        out = out + self.bias if self.bias is not None else out
+        return out.reshape(out_shape)
+
+    def extra_repr(self) -> str:
+        return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
+            self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
+        )
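
A note on the packing layout for reviewers: each int32 in qweight/qzeros holds pack_num = 8 four-bit values, stored in the interleaved order [0, 2, 4, 6, 1, 3, 5, 7] rather than sequentially, which is the layout the AWQ CUDA kernels consume. Below is a minimal round-trip sketch of that layout; it assumes only torch, and the function names are illustrative, not part of the patch.

    import torch

    ORDER_MAP = [0, 2, 4, 6, 1, 3, 5, 7]  # AWQ's interleaved nibble order

    def pack_int4(vals):
        # vals: (rows, 8*k) int32 tensor with entries in [0, 15]
        packed = torch.zeros(vals.shape[0], vals.shape[1] // 8, dtype=torch.int32)
        for col in range(packed.shape[1]):
            for i in range(8):
                # nibble i of the int32 comes from source position ORDER_MAP[i]
                packed[:, col] |= vals[:, col * 8 + ORDER_MAP[i]] << (i * 4)
        return packed

    def unpack_int4(packed):
        vals = torch.zeros(packed.shape[0], packed.shape[1] * 8, dtype=torch.int32)
        for col in range(packed.shape[1]):
            for i in range(8):
                # arithmetic right shift is fine: the & 0xF mask drops sign bits
                vals[:, col * 8 + ORDER_MAP[i]] = (packed[:, col] >> (i * 4)) & 0xF
        return vals

    x = torch.randint(0, 16, (2, 16), dtype=torch.int32)
    assert torch.equal(unpack_int4(pack_int4(x)), x)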
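
And a hedged usage sketch of WQLinear.from_linear with init_only=True, which per the "prepare empty buffers" comment only allocates packed buffers so a pre-quantized AWQ state dict can be loaded into them afterwards. The layer sizes below are placeholders, and actual inference still requires the awq_inference_engine CUDA kernel.

    import torch.nn as nn
    from text_generation_server.utils.awq.quantize.qmodule import WQLinear

    base = nn.Linear(4096, 4096, bias=False)  # placeholder layer sizes
    q = WQLinear.from_linear(base, w_bit=4, group_size=128, init_only=True)
    # Buffers now exist in the packed layout:
    #   q.qweight: (4096, 512)  int32   -- 8 x 4-bit weights per int32
    #   q.qzeros:  (32, 512)    int32   -- packed zero points, one row per group
    #   q.scales:  (32, 4096)   float16 -- fp16 scales, one row per group
    # A pre-quantized checkpoint could then be loaded, e.g. (hypothetical dict):
    # q.load_state_dict(awq_state_dict, strict=False)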