mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Refactored WQLinear
This commit is contained in:
parent
f85a6f853e
commit
5d0973f484
@ -1,4 +1,4 @@
|
|||||||
# Copied from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
|
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
@ -17,77 +17,29 @@ class ScaledActivation(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class WQLinear(nn.Module):
|
class WQLinear(nn.Module):
|
||||||
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
|
def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if w_bit not in [4]:
|
if w_bit not in [4]:
|
||||||
raise NotImplementedError("Only 4-bit are supported for now.")
|
raise NotImplementedError("Only 4-bit are supported for now.")
|
||||||
|
|
||||||
self.in_features = in_features
|
self.in_features = qweight.shape[0]
|
||||||
self.out_features = out_features
|
self.out_features = qweight.shape[1] * 32 // w_bit
|
||||||
|
|
||||||
self.w_bit = w_bit
|
self.w_bit = w_bit
|
||||||
self.group_size = group_size if group_size != -1 else in_features
|
self.group_size = group_size if group_size != -1 else self.in_features
|
||||||
# quick sanity check (make sure alignment)
|
# quick sanity check (make sure alignment)
|
||||||
assert self.in_features % self.group_size == 0
|
assert self.in_features % self.group_size == 0
|
||||||
assert out_features % (32 // self.w_bit) == 0
|
assert self.out_features % (32 // self.w_bit) == 0
|
||||||
|
|
||||||
self.register_buffer('qweight', torch.zeros((in_features, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
|
self.register_buffer('qweight', qweight)
|
||||||
self.register_buffer('qzeros', torch.zeros((in_features // self.group_size, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
|
self.register_buffer('qzeros', qzeros)
|
||||||
self.register_buffer('scales', torch.zeros((in_features // self.group_size, out_features), dtype=torch.float16, device=dev))
|
self.register_buffer('scales', scales)
|
||||||
if bias:
|
if bias:
|
||||||
self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16, device=dev))
|
self.register_buffer('bias', bias)
|
||||||
else:
|
else:
|
||||||
self.bias = None
|
self.bias = None
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None):
|
|
||||||
awq_linear = cls(w_bit, group_size, linear.in_features, linear.out_features, linear.bias is not None, linear.weight.device)
|
|
||||||
if init_only: # just prepare for loading sd
|
|
||||||
return awq_linear
|
|
||||||
|
|
||||||
# need scales and zeros info for real quantization
|
|
||||||
assert scales is not None and zeros is not None
|
|
||||||
scale_zeros = zeros * scales
|
|
||||||
|
|
||||||
awq_linear.scales = scales.clone().half()
|
|
||||||
if linear.bias is not None:
|
|
||||||
awq_linear.bias = linear.bias.clone().half()
|
|
||||||
|
|
||||||
pack_num = 32 // awq_linear.w_bit
|
|
||||||
|
|
||||||
intweight = []
|
|
||||||
for idx in range(awq_linear.in_features):
|
|
||||||
intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[idx // group_size]) / awq_linear.scales[idx // group_size]).to(torch.int)[:, None])
|
|
||||||
intweight = torch.cat(intweight, dim=1)
|
|
||||||
intweight = intweight.t().contiguous()
|
|
||||||
intweight = intweight.to(dtype=torch.int32)
|
|
||||||
qweight = torch.zeros((intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=intweight.device)
|
|
||||||
|
|
||||||
for col in range(intweight.shape[1] // pack_num):
|
|
||||||
if awq_linear.w_bit == 4:
|
|
||||||
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Only 4-bit are supported for now.")
|
|
||||||
for i in range(pack_num):
|
|
||||||
qweight_col = intweight[:, col * pack_num + order_map[i]]
|
|
||||||
qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
|
|
||||||
awq_linear.qweight = qweight
|
|
||||||
|
|
||||||
zeros = zeros.to(dtype=torch.int32)
|
|
||||||
qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=zeros.device)
|
|
||||||
|
|
||||||
for col in range(zeros.shape[1] // pack_num):
|
|
||||||
if awq_linear.w_bit == 4:
|
|
||||||
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Only 4-bit are supported for now.")
|
|
||||||
for i in range(pack_num):
|
|
||||||
qzero_col = zeros[:, col * pack_num + order_map[i]]
|
|
||||||
qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
|
|
||||||
awq_linear.qzeros = qzeros
|
|
||||||
|
|
||||||
return awq_linear
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
out_shape = x.shape[:-1] + (self.out_features, )
|
out_shape = x.shape[:-1] + (self.out_features, )
|
||||||
|
@ -256,12 +256,7 @@ def get_linear(weight, bias, quantize):
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"The passed weight is not `awq` compatible, loader needs to be updated."
|
f"The passed weight is not `awq` compatible, loader needs to be updated."
|
||||||
)
|
)
|
||||||
in_features = qweight.shape[0]
|
linear = WQLinear(w_bit=bits, group_size=groupsize, qweight=qweight, qzeros=qzeros, scales=scales, bias=bias is not None)
|
||||||
out_features = qweight.shape[1] * 32 // bits
|
|
||||||
linear = WQLinear(w_bit=bits, group_size=groupsize, in_features=in_features, out_features=out_features, bias=bias is not None, dev=qweight.device)
|
|
||||||
linear.qweight = qweight
|
|
||||||
linear.qzeros = qzeros
|
|
||||||
linear.scales = scales
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
|
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
|
||||||
return linear
|
return linear
|
||||||
|
Loading…
Reference in New Issue
Block a user