# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py

import math

import torch
import torch.nn as nn

import awq_inference_engine  # with CUDA kernels


# class ScaledActivation(nn.Module):
#     def __init__(self, module, scales):
#         super().__init__()
#         self.act = module
#         self.scales = nn.Parameter(scales.data)
#
#     def forward(self, x):
#         return self.act(x) / self.scales.view(1, 1, -1).to(x.device)


class WQLinear(nn.Module):
    def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
        super().__init__()

        if w_bit not in [4]:
            raise NotImplementedError("Only 4-bit is supported for now.")

        # qweight packs 32 // w_bit quantized weights into each int32 along
        # dim 1, so the logical out_features must be unpacked from its shape.
        self.in_features = qweight.shape[0]
        self.out_features = qweight.shape[1] * 32 // w_bit

        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else self.in_features
        # Quick sanity checks (make sure everything is aligned).
        assert self.in_features % self.group_size == 0
        assert self.out_features % (32 // self.w_bit) == 0

        self.qweight = qweight
        self.qzeros = qzeros
        self.scales = scales
        # `bias` is a tensor or None; truth-testing a multi-element tensor
        # would raise, so compare against None explicitly.
        if bias is not None:
            self.bias = bias
        else:
            self.bias = None

    @torch.no_grad()
    def forward(self, x):
        out_shape = x.shape[:-1] + (self.out_features,)
        # Flatten any leading batch dims into one, run the fused
        # dequantize + GEMM kernel, then restore the original shape.
        # The trailing 8 is a kernel parameter copied from upstream llm-awq.
        out = awq_inference_engine.gemm_forward_cuda(
            x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
        )
        out = out + self.bias if self.bias is not None else out
        return out.reshape(out_shape)
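

# A minimal usage sketch, not part of the upstream file: it assumes 4-bit AWQ
# packing (pack factor 32 // 4 = 8), a CUDA device, and the llm-awq tensor
# layout (qweight: [in_features, out_features // 8] int32; qzeros:
# [in_features // group_size, out_features // 8] int32; scales:
# [in_features // group_size, out_features] fp16). The zero-filled tensors
# below are placeholders; real checkpoints ship them already packed.
if __name__ == "__main__":
    in_features, out_features, group_size = 4096, 4096, 128
    pack_factor = 32 // 4

    qweight = torch.zeros(
        (in_features, out_features // pack_factor),
        dtype=torch.int32,
        device="cuda",
    )
    qzeros = torch.zeros(
        (in_features // group_size, out_features // pack_factor),
        dtype=torch.int32,
        device="cuda",
    )
    scales = torch.ones(
        (in_features // group_size, out_features),
        dtype=torch.float16,
        device="cuda",
    )

    linear = WQLinear(4, group_size, qweight, qzeros, scales, bias=None)
    x = torch.randn(1, 8, in_features, dtype=torch.float16, device="cuda")
    print(linear(x).shape)  # torch.Size([1, 8, 4096])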