from dataclasses import dataclass
from typing import List, Optional, Union

import numpy
import torch
import torch.nn as nn
from loguru import logger
from text_generation_server.layers.marlin.util import (
    _check_marlin_kernels,
    marlin_zero_points,
    permute_scales,
    unpack_cols,
)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

try:
    import marlin_kernels
except ImportError:
    marlin_kernels = None

try:
    major, _minor = torch.cuda.get_device_capability()
    has_sm_8_0 = major >= 8
except Exception:
    has_sm_8_0 = False


GPTQ_MARLIN_BITS = [4, 8]
GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
MARLIN_TILE_SIZE = 16


def can_use_gptq_marlin(
    *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool
) -> bool:
    return (
        SYSTEM == "cuda"
        and marlin_kernels is not None
        and has_sm_8_0
        and quantize in {"awq", "gptq"}
        and quant_method in {"awq", "gptq"}
        and bits in GPTQ_MARLIN_BITS
        and groupsize in GPTQ_MARLIN_GROUP_SIZES
        # We only suppord asymmetric quantization for AWQ.
        and (sym or quant_method == "awq")
    )


class GPTQMarlinWeightsLoader(WeightsLoader):
    """
    Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels.
    """

    def __init__(
        self,
        *,
        bits: int,
        desc_act: bool,
        groupsize: int,
        quant_method: str,
        quantize: str,
        sym: bool,
    ):
        self.bits = bits
        self.desc_act = desc_act
        self.groupsize = groupsize
        self.quant_method = quant_method
        self.quantize = quantize
        self.sym = sym

    def get_weights(self, weights: Weights, prefix: str):
        log_once(logger.info, "Using GPTQ-Marlin kernels")
        try:
            qweight = weights.get_tensor(f"{prefix}.qweight")
        except RuntimeError:
            raise RuntimeError(
                f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
            )

        if not self.sym:
            qzeros = weights.get_tensor(f"{prefix}.qzeros")
        else:
            qzeros = None

        if self.quant_method == "awq":
            g_idx = None
        else:
            g_idx = weights.get_tensor(f"{prefix}.g_idx")
        scales = weights.get_tensor(f"{prefix}.scales")

        return repack_gptq_for_marlin(
            qweight=qweight,
            scales=scales,
            qzeros=qzeros,
            g_idx=g_idx,
            bits=self.bits,
            desc_act=self.desc_act,
            groupsize=self.groupsize,
            quant_method=self.quant_method,
            sym=self.sym,
            sharded_infeatures=False,
        )

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):

        try:
            qweight = weights.get_packed_sharded(
                f"{prefix}.qweight", dim=1, block_sizes=block_sizes
            )
        except RuntimeError:
            raise RuntimeError(
                f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
            )
        scales = weights.get_packed_sharded(
            f"{prefix}.scales", dim=1, block_sizes=block_sizes
        )
        scales = scales.to(dtype=weights.dtype)

        if not self.sym:
            qzeros = weights.get_packed_sharded(
                f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
            )
        else:
            qzeros = None

        if self.quant_method == "awq":
            g_idx = None
        else:
            g_idx = weights.get_tensor(f"{prefix}.g_idx")
        return repack_gptq_for_marlin(
            qweight=qweight,
            scales=scales,
            qzeros=qzeros,
            g_idx=g_idx,
            bits=self.bits,
            desc_act=self.desc_act,
            groupsize=self.groupsize,
            quant_method=self.quant_method,
            sym=self.sym,
            sharded_infeatures=False,
        )

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        try:
            qweight = torch.cat(
                [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
            )
        except RuntimeError:
            raise RuntimeError(
                f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
            )

        scales = torch.cat(
            [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
        )

        if not self.sym:
            qzeros = torch.cat(
                [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
            )
        else:
            qzeros = None

        if self.quant_method == "awq":
            g_idx = None
        else:
            w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
            for w2 in w[1:]:
                torch.testing.assert_close(w2, w[0])
            g_idx = w[0]

        return repack_gptq_for_marlin(
            qweight=qweight,
            scales=scales,
            qzeros=qzeros,
            g_idx=g_idx,
            bits=self.bits,
            desc_act=self.desc_act,
            groupsize=self.groupsize,
            quant_method=self.quant_method,
            sym=self.sym,
            sharded_infeatures=False,
        )

    def get_weights_row(self, weights: Weights, prefix: str):
        log_once(logger.info, "Using GPTQ-Marlin kernels")
        try:
            qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
        except RuntimeError:
            raise RuntimeError(
                f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
            )

        if not self.sym:
            if self.desc_act or self.groupsize == -1:
                qzeros = weights.get_tensor(f"{prefix}.qzeros")
            else:
                qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
        else:
            qzeros = None

        if self.quant_method == "awq":
            g_idx = None
        else:
            g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)

        if self.desc_act or self.groupsize == -1:
            scales = weights.get_tensor(f"{prefix}.scales")
        else:
            scales = weights.get_sharded(f"{prefix}.scales", dim=0)

        sharded_in_features = weights.process_group.size() > 1

        return repack_gptq_for_marlin(
            qweight=qweight,
            scales=scales,
            qzeros=qzeros,
            g_idx=g_idx,
            bits=self.bits,
            desc_act=self.desc_act,
            groupsize=self.groupsize,
            quant_method=self.quant_method,
            sym=self.sym,
            sharded_infeatures=sharded_in_features,
        )

    def _get_gptq_params(self, weights: Weights):
        if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
            self.bits = weights.get_tensor("gptq_bits").item()
            self.groupsize = weights.get_tensor("gptq_groupsize").item()
            self.desc_act = False
            # `server quantize` used asymmetric quantization unconditionally
            # before the `gptq_sym` setting tensor was added.
            self.sym = (
                weights.get_tensor("gptq_sym").item()
                if weights._has_tensor("gptq_sym")
                else False
            )
            self.quant_method = "gptq"


@dataclass
class GPTQMarlinWeight(Weight):
    """
    Repacked GPTQ Marlin weights.
    """

    qweight: torch.Tensor
    qzeros: torch.Tensor
    scales: torch.Tensor
    g_idx: torch.Tensor
    perm: torch.Tensor
    bits: int
    is_full_k: bool

    def __post_init__(self):
        assert self.qweight.dtype == torch.int32
        assert self.scales.dtype == torch.float16
        assert self.g_idx.dtype == torch.int32
        assert self.perm.dtype == torch.int32

    def get_linear(self, bias: torch.Tensor):
        return GPTQMarlinLinear(
            weight=self,
            bias=bias,
        )


def repack_gptq_for_marlin(
    *,
    qweight: torch.Tensor,
    qzeros: Optional[torch.Tensor],
    scales: torch.Tensor,
    g_idx: Optional[torch.Tensor],
    bits: int,
    desc_act: bool,
    groupsize: int,
    quant_method: str,
    sym: bool,
    sharded_infeatures: bool,
) -> GPTQMarlinWeight:
    """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
    _check_marlin_kernels()
    assert marlin_kernels is not None

    if bits not in GPTQ_MARLIN_BITS:
        supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
        raise RuntimeError(
            f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}"
        )

    if groupsize not in GPTQ_MARLIN_GROUP_SIZES:
        supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES)
        raise RuntimeError(
            f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
        )
    if not (sym or quant_method == "awq"):
        raise RuntimeError(
            "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
        )

    log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.")

    weights_per_int = 32 // bits
    in_features = qweight.shape[0]
    out_features = qweight.shape[1]

    # AWQ uses column packing, GPTQ uses row packing
    if quant_method == "awq":
        out_features *= weights_per_int
    else:
        in_features *= weights_per_int

    if in_features % groupsize != 0:
        raise ValueError(
            f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
        )

    if g_idx is not None and desc_act and groupsize != -1:
        perm = torch.argsort(g_idx).to(torch.int)
        g_idx = g_idx[perm]
    else:
        perm = torch.empty(0, dtype=torch.int, device=qweight.device)
        g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)

    if quant_method == "awq":
        repacked = marlin_kernels.awq_marlin_repack(
            qweight, in_features, out_features, bits
        )
        if qzeros is not None:
            qzeros = awq_to_marlin_zero_points(
                qzeros,
                in_features // groupsize,
                out_features,
                bits,
            )

    else:
        repacked = marlin_kernels.gptq_marlin_repack(
            qweight, perm, in_features, out_features, bits
        )

    if qzeros is None:
        qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)

    scales = permute_scales(scales)

    is_full_k = not (desc_act and sharded_infeatures)

    return GPTQMarlinWeight(
        qweight=repacked,
        qzeros=qzeros,
        scales=scales,
        g_idx=g_idx,
        perm=perm,
        bits=bits,
        is_full_k=is_full_k,
    )


class GPTQMarlinLinear(nn.Module):
    """
    Linear layer for GPTQ weights that were converted for the GPTQ-Marlin
    kernels.
    """

    def __init__(
        self,
        *,
        weight: GPTQMarlinWeight,
        bias: Optional[torch.Tensor],
    ):
        super().__init__()

        _check_marlin_kernels()
        assert marlin_kernels is not None

        in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
        out_features = weight.scales.shape[1]
        _check_valid_shape(in_features=in_features, out_features=out_features)

        self.bits = weight.bits
        self.is_full_k = weight.is_full_k

        self.qweight = weight.qweight
        self.qzeros = weight.qzeros
        self.scales = weight.scales
        self.g_idx = weight.g_idx
        self.perm = weight.perm
        if bias is not None:
            self.bias = bias
        else:
            self.bias = None

        self.workspace = torch.zeros(
            out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device
        )

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        assert marlin_kernels is not None

        A_flat = A.view(-1, A.shape[-1])
        C = marlin_kernels.gptq_marlin_gemm(
            A_flat,
            self.qweight,
            self.scales,
            self.qzeros,
            self.g_idx,
            self.perm,
            self.workspace,
            self.bits,
            A_flat.shape[0],
            self.scales.shape[1],
            A_flat.shape[1],
            self.is_full_k,
            self.qzeros.numel() > 0,
            True,
        )
        C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))

        if self.bias is not None:
            C += self.bias

        return C


def awq_to_marlin_zero_points(
    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    # AWQ zero-points are quantized and packed on the column dim.
    # In addition, the values are permuted based on dequantizer.
    # Here we undo both of these, and then apply marlin permutation
    # and pack it back.
    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

    # Undo interleaving (use argsort(..) to get inverse perm)
    if num_bits == 4:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
    elif num_bits == 8:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
    q_zp = q_zp.reshape((-1, size_n)).contiguous()

    marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
    return marlin_zp


def _check_valid_shape(in_features: int, out_features: int):
    if (in_features % 128 != 0 or out_features % 64 != 0) and (
        in_features % 64 != 0 or out_features % 128 != 0
    ):
        raise ValueError(
            f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})."
            " The shape elements must be divisible by (128, 64) or (64, 128)."
        )