diff --git a/server/text_generation_server/utils/awq/quantize/qmodule.py b/server/text_generation_server/utils/awq/quantize/qmodule.py
index e157ce55..fb1adf5c 100644
--- a/server/text_generation_server/utils/awq/quantize/qmodule.py
+++ b/server/text_generation_server/utils/awq/quantize/qmodule.py
@@ -1,4 +1,4 @@
-# Copied from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
+# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
 
 import math
 import torch
@@ -17,77 +17,29 @@ class ScaledActivation(nn.Module):
 
 
 class WQLinear(nn.Module):
-    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
+    def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
         super().__init__()
 
         if w_bit not in [4]:
             raise NotImplementedError("Only 4-bit are supported for now.")
 
-        self.in_features = in_features
-        self.out_features = out_features
+        self.in_features = qweight.shape[0]
+        self.out_features = qweight.shape[1] * 32 // w_bit
+
         self.w_bit = w_bit
-        self.group_size = group_size if group_size != -1 else in_features
+        self.group_size = group_size if group_size != -1 else self.in_features
         # quick sanity check (make sure aligment)
         assert self.in_features % self.group_size == 0
-        assert out_features % (32 // self.w_bit) == 0
+        assert self.out_features % (32 // self.w_bit) == 0
 
-        self.register_buffer('qweight', torch.zeros((in_features, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
-        self.register_buffer('qzeros', torch.zeros((in_features // self.group_size, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
-        self.register_buffer('scales', torch.zeros((in_features // self.group_size, out_features), dtype=torch.float16, device=dev))
-        if bias:
-            self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16, device=dev))
+        self.register_buffer('qweight', qweight)
+        self.register_buffer('qzeros', qzeros)
+        self.register_buffer('scales', scales)
+        if bias is not None:
+            self.register_buffer('bias', bias)
         else:
             self.bias = None
 
-    @classmethod
-    def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None):
-        awq_linear = cls(w_bit, group_size, linear.in_features, linear.out_features, linear.bias is not None, linear.weight.device)
-        if init_only:  # just prepare for loading sd
-            return awq_linear
-
-        # need scales and zeros info for real quantization
-        assert scales is not None and zeros is not None
-        scale_zeros = zeros * scales
-
-        awq_linear.scales = scales.clone().half()
-        if linear.bias is not None:
-            awq_linear.bias = linear.bias.clone().half()
-
-        pack_num = 32 // awq_linear.w_bit
-
-        intweight = []
-        for idx in range(awq_linear.in_features):
-            intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[idx // group_size]) / awq_linear.scales[idx // group_size]).to(torch.int)[:, None])
-        intweight = torch.cat(intweight, dim=1)
-        intweight = intweight.t().contiguous()
-        intweight = intweight.to(dtype=torch.int32)
-        qweight = torch.zeros((intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=intweight.device)
-
-        for col in range(intweight.shape[1] // pack_num):
-            if awq_linear.w_bit == 4:
-                order_map = [0, 2, 4, 6, 1, 3, 5, 7]
-            else:
-                raise NotImplementedError("Only 4-bit are supported for now.")
-            for i in range(pack_num):
-                qweight_col = intweight[:, col * pack_num + order_map[i]]
-                qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
-        awq_linear.qweight = qweight
-
-        zeros = zeros.to(dtype=torch.int32)
-        qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=zeros.device)
-
-        for col in range(zeros.shape[1] // pack_num):
-            if awq_linear.w_bit == 4:
-                order_map = [0, 2, 4, 6, 1, 3, 5, 7]
-            else:
-                raise NotImplementedError("Only 4-bit are supported for now.")
-            for i in range(pack_num):
-                qzero_col = zeros[:, col * pack_num + order_map[i]]
-                qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
-        awq_linear.qzeros = qzeros
-
-        return awq_linear
-
     @torch.no_grad()
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.out_features, )
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index f0227177..3fb3766a 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -256,12 +256,7 @@ def get_linear(weight, bias, quantize):
             raise NotImplementedError(
                 f"The passed weight is not `awq` compatible, loader needs to be updated."
             )
-        in_features = qweight.shape[0]
-        out_features = qweight.shape[1] * 32 // bits
-        linear = WQLinear(w_bit=bits, group_size=groupsize, in_features=in_features, out_features=out_features, bias=bias is not None, dev=qweight.device)
-        linear.qweight = qweight
-        linear.qzeros = qzeros
-        linear.scales = scales
+        linear = WQLinear(w_bit=bits, group_size=groupsize, qweight=qweight, qzeros=qzeros, scales=scales, bias=bias)
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
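For reference, the `from_linear` helper removed above is what produces the `qweight`/`qzeros` layout that the constructor now receives pre-built from the checkpoint: eight 4-bit values are packed into each int32 column using the interleaved order `[0, 2, 4, 6, 1, 3, 5, 7]`. The sketch below is a minimal, self-contained illustration of that packing scheme; the `pack`/`unpack` helper names are hypothetical and not part of TGI or llm-awq.

```python
import torch

W_BIT = 4
PACK_NUM = 32 // W_BIT                 # eight 4-bit values per int32
ORDER_MAP = [0, 2, 4, 6, 1, 3, 5, 7]   # interleaving used by the removed from_linear


def pack(intweight: torch.Tensor) -> torch.Tensor:
    """Pack a [rows, cols] tensor of 4-bit integers into [rows, cols // 8] int32."""
    assert intweight.shape[1] % PACK_NUM == 0
    qweight = torch.zeros(
        intweight.shape[0], intweight.shape[1] // PACK_NUM, dtype=torch.int32
    )
    for col in range(qweight.shape[1]):
        for i in range(PACK_NUM):
            # Nibble i of each packed int32 holds source column col * 8 + ORDER_MAP[i].
            src = intweight[:, col * PACK_NUM + ORDER_MAP[i]].to(torch.int32)
            qweight[:, col] |= src << (i * W_BIT)
    return qweight


def unpack(qweight: torch.Tensor) -> torch.Tensor:
    """Invert pack(), recovering the original 4-bit values."""
    rows, packed_cols = qweight.shape
    out = torch.zeros(rows, packed_cols * PACK_NUM, dtype=torch.int32)
    for col in range(packed_cols):
        for i in range(PACK_NUM):
            out[:, col * PACK_NUM + ORDER_MAP[i]] = (qweight[:, col] >> (i * W_BIT)) & 0xF
    return out


if __name__ == "__main__":
    w = torch.randint(0, 16, (4, 16), dtype=torch.int32)  # sixteen columns of 4-bit values
    assert torch.equal(unpack(pack(w)), w)
    print("round-trip OK, packed shape:", tuple(pack(w).shape))  # (4, 2)
```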
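With this change, `get_linear` hands the pre-packed checkpoint tensors straight to the constructor and the in/out feature counts are derived from `qweight` instead of being passed in. Below is a rough usage sketch under assumed dimensions (4-bit weights, group size 128); it presumes the AWQ CUDA kernels (`awq_inference_engine`) that `qmodule.py` imports are installed and that the tensors live on a CUDA device for the forward pass.

```python
import torch
from text_generation_server.utils.awq.quantize.qmodule import WQLinear

# Assumed dimensions for illustration only.
w_bit, group_size = 4, 128
in_features, out_features = 4096, 4096
pack_num = 32 // w_bit  # eight 4-bit values per int32

# Pre-packed tensors as they come out of an AWQ checkpoint (zeros/ones here
# only to show the expected shapes and dtypes).
qweight = torch.zeros(in_features, out_features // pack_num, dtype=torch.int32, device="cuda")
qzeros = torch.zeros(in_features // group_size, out_features // pack_num, dtype=torch.int32, device="cuda")
scales = torch.ones(in_features // group_size, out_features, dtype=torch.float16, device="cuda")
bias = None  # or a float16 tensor of shape (out_features,)

linear = WQLinear(
    w_bit=w_bit,
    group_size=group_size,
    qweight=qweight,
    qzeros=qzeros,
    scales=scales,
    bias=bias,
)

# in_features / out_features are now derived from qweight rather than passed in.
assert (linear.in_features, linear.out_features) == (in_features, out_features)

x = torch.zeros(2, in_features, dtype=torch.float16, device="cuda")
y = linear(x)  # runs the AWQ CUDA GEMM kernel
assert y.shape == (2, out_features)
```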