from typing import Optional

import torch
import torch.nn as nn
from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.layers.marlin.gptq import _check_valid_shape
from text_generation_server.layers.marlin.util import (
    _check_marlin_kernels,
    permute_scales,
)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel

# The quantization kernels used here (Marlin FP8 GEMM and repack ops) are
# CUDA-only; load them from the kernels-community/quantization Hub repo.
if SYSTEM == "cuda":
    quantization = load_kernel(
        module="quantization", repo_id="kernels-community/quantization"
    )
else:
    quantization = None


MARLIN_TILE_SIZE = 16


class GPTQMarlinFP8Linear(nn.Module):
    """
    FP8 GPTQ-Marlin linear layer.
    """

    def __init__(
        self,
        qweight: torch.Tensor,
        scales: torch.Tensor,
        bias: Optional[torch.Tensor],
    ) -> None:
        super().__init__()

        _check_marlin_kernels()
        assert quantization is not None

        # A per-tensor (scalar) scale is broadcast to one scale per output
        # channel, which is the layout the Marlin kernel expects.
        scales = scales.unsqueeze(0)
        if scales.shape[1] == 1:
            out_features, in_features = qweight.shape
            scales = scales.repeat(1, out_features)
        qweight, scales = repack_fp8_for_marlin(qweight, scales)

        # Recover the logical dimensions: the repacked qweight counts
        # in_features in units of 16x16 Marlin tiles, and the permuted scales
        # keep one column per output feature.
        in_features = qweight.shape[0] * MARLIN_TILE_SIZE
        out_features = scales.shape[1]
        _check_valid_shape(in_features=in_features, out_features=out_features)

        self.qweight = qweight
        self.scales = scales
        self.bias = bias

        # Zero-initialized scratch buffer the Marlin kernel uses to synchronize
        # partial results across thread blocks.
        self.workspace = torch.zeros(
            out_features // 64 * 16, dtype=torch.int, device=qweight.device
        )

    @classmethod
    def from_unquant(cls, weight, bias, dtype):
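        """Quantize an unquantized weight to FP8 and wrap it in a Marlin layer."""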
        qweight, scales = fp8_quantize(weight)
        return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)

    @classmethod
    def from_fp8(
        cls,
        weight: torch.Tensor,
        scale: torch.Tensor,
        bias: torch.Tensor,
        dtype: torch.dtype,
        **kwargs,
    ):
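        """Wrap an already FP8-quantized weight and its scale in a Marlin layer."""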
        return cls(qweight=weight, scales=scale.to(dtype), bias=bias)

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        assert quantization is not None

        A_flat = A.view(-1, A.shape[-1])
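        # fp8_marlin_gemm(A, qweight, scales, workspace, num_bits, m, n, k):
        # num_bits is 8 for FP8 weights; (m, n, k) are the GEMM problem sizes,
        # i.e. (flattened batch rows, out_features, in_features).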
        C = quantization.fp8_marlin_gemm(
            A_flat,
            self.qweight,
            self.scales,
            self.workspace,
            8,
            A_flat.shape[0],
            self.scales.shape[1],
            A_flat.shape[1],
        )
        C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))

        if self.bias is not None:
            C += self.bias

        return C
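

# Example usage (a minimal sketch; assumes a CUDA device, the quantization
# kernels being available, and a weight shape accepted by _check_valid_shape):
#
#     weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
#     linear = GPTQMarlinFP8Linear.from_unquant(weight, bias=None, dtype=torch.float16)
#     x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
#     y = linear(x)  # shape [8, 4096]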


def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
    """
    Repack FP8 weights to gptq format (packed int32 elements).
    """
    assert fp8_tensor.dtype == torch.float8_e4m3fn

    if fp8_tensor.shape[0] % 4 != 0:
        raise ValueError(
            f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}"
        )

    # Group every 4 consecutive rows along a new axis for packing
    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])

    # Reinterpret the FP8 values as their raw uint8 bytes
    byte_tensor = reshaped.view(torch.uint8)

    # Pack 4 uint8 values into one int32
    packed = torch.zeros(
        fp8_tensor.shape[0] // 4,
        fp8_tensor.shape[1],
        dtype=torch.int32,
        device=fp8_tensor.device,
    )

    # Byte i of each 4-row group occupies bits [8*i, 8*i + 8) of the int32
    # (little-endian packing, matching the GPTQ layout).
    for i in range(4):
        packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)

    return packed
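

# For example (a sketch): an FP8 tensor of shape [8, 4096] packs into an int32
# tensor of shape [2, 4096], with rows 0-3 of each group of four stored in
# bytes 0-3 of the corresponding int32.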


def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
    """
    Repack FP8 tensor for GPTQ-Marlin.
    """

    out_features, in_features = weight.shape

    # Torch linear layer weights have shape [out_features, in_features], while
    # GPTQ-packed weights use [in_features / pack_factor, out_features], so
    # transpose before packing.
    qweight = pack_fp8_as_int32(weight.t())

    # No act-order (g_idx) permutation is used, so pass an empty perm tensor.
    perm = torch.empty(0, dtype=torch.int, device=qweight.device)
    repacked = quantization.gptq_marlin_repack(
        qweight, perm, in_features, out_features, 8
    )

    # Permute the scales into the layout expected by the Marlin kernel.
    scales = permute_scales(scales)

    return repacked, scales