text-generation-inference/server/text_generation_server/layers/awq/quantize/qmodule.py
Wang, Yi A 0b02d45a05 add gptq and awq int4 support in intel platform
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-08-21 22:47:34 -07:00

72 lines
2.4 KiB
Python

# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
from typing import Optional
import torch
import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
else:
import awq_inference_engine # with CUDA kernels
# class ScaledActivation(nn.Module):
# def __init__(self, module, scales):
# super().__init__()
# self.act = module
# self.scales = nn.Parameter(scales.data)
#
# def forward(self, x):
# return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
class WQLinear(nn.Module):
    """AWQ weight-only quantized linear layer (4-bit only).

    Depending on the module-level ``SYSTEM`` value, the matrix multiply is
    executed either by an IPEX weight-only-quantized linear module
    (``SYSTEM == "ipex"``) or by the ``awq_inference_engine`` CUDA kernel.

    Args:
        w_bit: Weight bit-width. Only ``4`` is accepted.
        group_size: Quantization group size; ``-1`` means a single group
            spanning all input features.
        qweight: Packed quantized weights. ``shape[0]`` is taken as the number
            of input features and ``shape[1] * 32 // w_bit`` as the number of
            output features (presumably int32-packed -- confirm with the
            loader that produces this tensor).
        qzeros: Quantized zero points, forwarded unchanged to the backend.
        scales: Quantization scales, forwarded unchanged to the backend.
        bias: Optional bias tensor, or ``None``.

    Raises:
        NotImplementedError: If ``w_bit`` is not 4.
        AssertionError: If the feature sizes are not aligned with
            ``group_size`` or with the packing factor.
    """

    def __init__(
        self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
    ):
        super().__init__()

        if w_bit not in [4]:
            raise NotImplementedError("Only 4-bit are supported for now.")

        self.in_features = qweight.shape[0]
        self.out_features = qweight.shape[1] * 32 // w_bit

        self.w_bit = w_bit
        # -1 is the "no grouping" sentinel: use one group over all inputs.
        self.group_size = group_size if group_size != -1 else self.in_features
        # quick sanity check (make sure alignment)
        assert self.in_features % self.group_size == 0
        assert self.out_features % (32 // self.w_bit) == 0

        self.qweight = qweight
        self.qzeros = qzeros
        self.scales = scales
        self.bias = bias

        if SYSTEM == "ipex":
            # The IPEX module is built once here and receives the bias, so it
            # owns the bias handling on this backend.
            self.woq_linear = (
                ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
                    self.qweight,
                    self.scales,
                    self.qzeros,
                    self.in_features,
                    self.out_features,
                    bias=self.bias,
                    group_size=self.group_size,
                    quant_method=ipex.llm.quantization.QuantMethod.AWQ_GEMM,
                    dtype=ipex.llm.quantization.QuantDtype.INT4,
                )
            )

    @torch.no_grad()
    def forward(self, x):
        """Apply the quantized linear projection to ``x``.

        ``x`` is flattened to 2-D over its leading dimensions, multiplied by
        the quantized weights, and reshaped back so that only the last
        dimension changes (to ``out_features``).
        """
        out_shape = x.shape[:-1] + (self.out_features,)
        flat_x = x.reshape(-1, x.shape[-1])

        if SYSTEM == "ipex":
            # The bias was already given to ``from_weight`` in ``__init__``;
            # adding ``self.bias`` again here would count it twice.
            # NOTE(review): relies on the IPEX module applying the bias it was
            # constructed with -- confirm against the IPEX version in use.
            out = self.woq_linear(flat_x)
        else:
            # NOTE(review): the trailing ``8`` is presumably the kernel's
            # split-k iteration count as in upstream llm-awq -- confirm.
            out = awq_inference_engine.gemm_forward_cuda(
                flat_x, self.qweight, self.scales, self.qzeros, 8
            )
            # The CUDA kernel call takes no bias, so add it here.
            if self.bias is not None:
                out = out + self.bias

        return out.reshape(out_shape)