# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
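# 4-bit AWQ weight-only quantized linear layer: dispatches to IPEX on Intel
# (SYSTEM == "ipex") and to the llm-awq CUDA kernels everywhere else.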

from typing import Optional
import torch
import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "ipex":
    import intel_extension_for_pytorch as ipex
else:
    import awq_inference_engine  # with CUDA kernels


class WQLinear(nn.Module):
    def __init__(
        self,
        w_bit: int,
        group_size: int,
        qweight: torch.Tensor,
        qzeros: torch.Tensor,
        scales: torch.Tensor,
        bias: Optional[torch.Tensor],
    ):
        super().__init__()

        if w_bit not in [4]:
            raise NotImplementedError("Only 4-bit is supported for now.")

        # qweight packs 32 // w_bit quantized values into each int32, so its
        # shape is (in_features, out_features * w_bit // 32).
        self.in_features = qweight.shape[0]
        self.out_features = qweight.shape[1] * 32 // w_bit

        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else self.in_features
        # Quick sanity checks: input features must divide evenly into groups,
        # and output features into int32-packed words.
        assert self.in_features % self.group_size == 0
        assert self.out_features % (32 // self.w_bit) == 0

        self.qweight = qweight
        self.qzeros = qzeros
        self.scales = scales
        self.bias = bias
        if SYSTEM == "ipex":
            self.woq_linear = (
                ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
                    self.qweight,
                    self.scales,
                    self.qzeros,
                    self.in_features,
                    self.out_features,
                    bias=self.bias,
                    group_size=self.group_size,
                    quant_method=ipex.llm.quantization.QuantMethod.AWQ_GEMM,
                    dtype=ipex.llm.quantization.QuantDtype.INT4,
                )
            )

    @torch.no_grad()
    def forward(self, x):
        # Flatten leading dimensions into a single batch dimension for the
        # GEMM, then restore them on the way out.
        out_shape = x.shape[:-1] + (self.out_features,)
        if SYSTEM == "ipex":
            out = self.woq_linear(x.reshape(-1, x.shape[-1]))
        else:
            # The trailing 8 is the CUDA kernel's split-K iteration count.
            out = awq_inference_engine.gemm_forward_cuda(
                x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
            )
        if self.bias is not None:
            out = out + self.bias
        return out.reshape(out_shape)
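

# Minimal shape sketch, not a real quantization flow: the tensors below are
# random placeholders laid out the way the 4-bit AWQ GEMM expects (int32
# qweight/qzeros packing 32 // w_bit values per element, fp16 scales). It
# assumes a CUDA machine where awq_inference_engine imports successfully;
# the dimensions (4096 x 4096, group_size=128) are illustrative only.
if __name__ == "__main__":
    in_features, out_features, w_bit, group_size = 4096, 4096, 4, 128
    pack_factor = 32 // w_bit  # 8 int4 values per int32 word

    int32_info = torch.iinfo(torch.int32)
    qweight = torch.randint(
        int32_info.min,
        int32_info.max,
        (in_features, out_features // pack_factor),
        dtype=torch.int32,
        device="cuda",
    )
    qzeros = torch.randint(
        int32_info.min,
        int32_info.max,
        (in_features // group_size, out_features // pack_factor),
        dtype=torch.int32,
        device="cuda",
    )
    scales = torch.rand(
        in_features // group_size, out_features, dtype=torch.float16, device="cuda"
    )

    layer = WQLinear(w_bit, group_size, qweight, qzeros, scales, bias=None)
    x = torch.rand(1, in_features, dtype=torch.float16, device="cuda")
    print(layer(x).shape)  # torch.Size([1, 4096])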