mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 07:22:07 +00:00
128 lines
4.2 KiB
Python
128 lines
4.2 KiB
Python
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import math
|
|
|
|
|
|
class Quantizer(nn.Module):
    """Quantize-then-dequantize helper driven by a scale and a zero point.

    For ``maxq >= 0`` a value ``x`` is mapped to
    ``scale * (clamp(round(x / scale) + zero, 0, maxq) - zero)``.
    For ``maxq < 0`` (enabled through ``configure(..., trits=True)``) every
    value is mapped to one of three levels: ``scale``, ``zero`` or ``0``.

    ``maxq``, ``scale`` and ``zero`` are registered as module buffers.
    Call :meth:`configure` first, then :meth:`find_params` on a sample
    tensor, and finally :meth:`quantize`.
    """

    def __init__(self, shape=1):
        """Create the buffers; ``scale`` and ``zero`` start as zeros of ``shape``."""
        super().__init__()
        self.register_buffer('maxq', torch.tensor(0))
        self.register_buffer('scale', torch.zeros(shape))
        self.register_buffer('zero', torch.zeros(shape))

    def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False):
        """Store the quantization settings and reset ``scale`` to zeros.

        Args:
            bits: Bit width; the largest integer level is ``2**bits - 1``.
            perchannel: If true, ``find_params`` computes one scale/zero per
                row of the rearranged input instead of one for the whole
                tensor.
            sym: If true, use a range symmetric around zero and a fixed
                zero point of ``(maxq + 1) / 2``.
            mse: If true, ``find_params`` additionally searches shrunken
                ranges and keeps the one with the smallest error.
            norm: Exponent applied to the absolute error in that search.
            grid: Denominator of the shrink step (``p = 1 - i / grid``).
            maxshrink: The search runs ``int(maxshrink * grid)`` steps.
            trits: If true, ``maxq`` is set to ``-1``, which selects the
                three-level branch of ``_quantize``.
        """
        self.maxq = torch.tensor(2**bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        if trits:
            self.maxq = torch.tensor(-1)
        # An all-zero scale makes ready() return False until find_params runs.
        self.scale = torch.zeros_like(self.scale)

    def _quantize(self, x, scale, zero, maxq):
        """Return ``x`` quantized and mapped back to floating point.

        ``scale`` and ``zero`` must broadcast against ``x``.
        """
        if maxq < 0:
            # Three-level mode: values above scale/2 become `scale`, values
            # below zero/2 become `zero`, everything else becomes 0.
            return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
        q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
        return scale * (q - zero)

    def find_params(self, x, weight=False):
        """Compute ``self.scale`` and ``self.zero`` from the tensor ``x``.

        Args:
            x: Sample tensor. With ``weight=True`` dimension 0 is the
                per-channel axis. With ``weight=False`` the per-channel axis
                is dimension 1 for 4-D input and the last dimension for 3-D
                and 2-D input.
            weight: Selects the layout handling described above and the
                shape of the stored parameters.

        Returns:
            None. ``self.scale`` and ``self.zero`` are replaced by tensors
            reshaped to broadcast against an input laid out like ``x``.
        """
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        # Rearrange x to 2-D so that every row holds the values sharing one
        # scale / zero pair.
        if self.perchannel:
            if weight:
                x = x.flatten(1)
            elif len(shape) == 4:
                x = x.permute([1, 0, 2, 3])
                x = x.flatten(1)
            elif len(shape) == 3:
                x = x.reshape((-1, shape[-1])).t()
            elif len(shape) == 2:
                x = x.t()
        else:
            x = x.flatten().unsqueeze(0)

        # Row-wise range, widened so that it always contains 0.
        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmin < 0
            if torch.any(tmp):
                xmin[tmp] = -xmax[tmp]
        # Rows that are entirely zero get the range [-1, 1] so the scale
        # computed below is never zero.
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        if self.maxq < 0:
            # Three-level mode stores the two outer levels directly.
            self.scale = xmax
            self.zero = xmin
        else:
            self.scale = (xmax - xmin) / self.maxq
            if self.sym:
                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
            else:
                self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            # Try progressively shrunken ranges and keep, per row, the one
            # with the smallest sum of |quantized - x| ** norm.
            best = torch.full([x.shape[0]], float('inf'), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                # In symmetric mode zero1 is self.zero itself, so the
                # write-back to self.zero below leaves it unchanged.
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                tmp = err < best
                if torch.any(tmp):
                    best[tmp] = err[tmp]
                    self.scale[tmp] = scale1[tmp]
                    self.zero[tmp] = zero1[tmp]

        if not self.perchannel:
            # Replicate the single pair once per channel so the stored
            # shape is the same as in per-channel mode.
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        if weight:
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        elif len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        elif len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        """Quantize ``x`` with the stored parameters, or return ``x`` itself if not ready."""
        if self.ready():
            return self._quantize(x, self.scale, self.zero, self.maxq)

        return x

    def enabled(self):
        """Return a boolean tensor that is true when ``maxq > 0``."""
        return self.maxq > 0

    def ready(self):
        """Return a boolean tensor that is true when no entry of ``scale`` is zero."""
        return torch.all(self.scale != 0)