from typing import Optional
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex


class WQLinear(nn.Module):
    """AWQ weight-only-quantized linear layer backed by Intel Extension for PyTorch.

    The pre-quantized tensors are handed to
    ``ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight`` with
    ``QuantMethod.AWQ_GEMM`` and ``QuantDtype.INT4``; ``forward`` delegates to
    the module that call returns.

    Args:
        w_bit: Bit width of the quantized weights. Only 4 is supported.
        group_size: Number of input features sharing one scale/zero-point.
            ``-1`` means a single group spanning all ``in_features``.
        qweight: Packed quantized weight. ``in_features`` is read from dim 0
            and ``out_features`` is derived as ``qweight.shape[1] * 32 // w_bit``,
            i.e. the code assumes ``32 // w_bit`` values are packed per element
            along dim 1 (presumably int32 storage -- confirm against the loader).
        qzeros: Quantized zero points, passed through to IPEX unchanged.
        scales: Dequantization scales, passed through to IPEX unchanged.
        bias: Optional bias, passed through to IPEX unchanged.

    Raises:
        NotImplementedError: If ``w_bit`` is not 4.
        ValueError: If the resolved ``group_size`` is not positive, does not
            evenly divide ``in_features``, or if ``out_features`` is not a
            multiple of the packing factor ``32 // w_bit``.
    """

    def __init__(
        self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
    ):
        super().__init__()

        if w_bit not in [4]:
            raise NotImplementedError("Only 4-bit are supported for now.")

        self.in_features = qweight.shape[0]
        self.out_features = qweight.shape[1] * 32 // w_bit

        self.w_bit = w_bit
        # -1 is the sentinel for "one group covering every input feature".
        self.group_size = group_size if group_size != -1 else self.in_features

        # Alignment checks use explicit raises rather than `assert` so they
        # still run under `python -O` and never let a bad config reach IPEX.
        # A non-positive group size would otherwise slip through: Python's
        # `%` accepts negative divisors (8 % -2 == 0) and 0 raises a bare
        # ZeroDivisionError.
        if self.group_size <= 0:
            raise ValueError(
                f"group_size must be positive or -1, got {group_size} "
                f"(resolved to {self.group_size})"
            )
        if self.in_features % self.group_size != 0:
            raise ValueError(
                f"in_features ({self.in_features}) must be divisible by "
                f"group_size ({self.group_size})"
            )
        pack_factor = 32 // self.w_bit
        # Always true for w_bit == 4 with the formula above; kept as a guard in
        # case more bit widths are enabled later.
        if self.out_features % pack_factor != 0:
            raise ValueError(
                f"out_features ({self.out_features}) must be divisible by the "
                f"packing factor ({pack_factor})"
            )

        self.qweight = qweight
        self.qzeros = qzeros
        self.scales = scales
        self.bias = bias
        self.woq_linear = (
            ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
                self.qweight,
                self.scales,
                self.qzeros,
                self.in_features,
                self.out_features,
                bias=self.bias,
                group_size=self.group_size,
                quant_method=ipex.llm.quantization.QuantMethod.AWQ_GEMM,
                dtype=ipex.llm.quantization.QuantDtype.INT4,
            )
        )

    @torch.no_grad()
    def forward(self, x):
        """Apply the quantized linear layer to ``x`` (without autograd).

        All leading dimensions of ``x`` are flattened into a single batch
        dimension before calling the IPEX module, then restored, so the result
        has shape ``x.shape[:-1] + (out_features,)``.
        """
        out_shape = x.shape[:-1] + (self.out_features,)
        out = self.woq_linear(x.reshape(-1, x.shape[-1]))
        return out.reshape(out_shape)