from dataclasses import dataclass
import os
from typing import Optional, Tuple, Type, Union, List

import torch
from loguru import logger

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import (
    Weight,
    WeightsLoader,
    UnquantizedWeight,
    Weights,
)
from text_generation_server.utils.log import log_once

try:
    import marlin_kernels
except ImportError:
    marlin_kernels = None


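# Probe the marlin_kernels cutlass GEMM for FP8 support on the current device's
# compute capability.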
if SYSTEM == "cuda" and marlin_kernels is not None:
    major, minor = torch.cuda.get_device_capability()
    CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8(
        major * 10 + minor
    )
else:
    CUTLASS_FP8_AVAILABLE = False


def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
    """
    Return an FP8 linear `Module` that is compatible with the current system.
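
    Illustrative usage (a sketch; `qweight`, `weight_scale`, and `dtype` are
    assumed to come from an already-loaded FP8 checkpoint, and `x` is an input
    activation tensor):

        linear_cls = get_fp8_linear()
        linear = linear_cls.from_fp8(
            weight=qweight, scale=weight_scale, dtype=dtype, bias=None
        )
        y = linear(x)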
    """

    if SYSTEM == "cuda":

        major, _ = torch.cuda.get_device_capability()
        # Marlin is W8A16, use it when:
        #
        # - On capability 8.x where x < 8: W8A8 FP8 GEMM is not supported.
        # - On capability 8.9: W8A8 FP8 GEMM is supported, but Marlin-FP8 is faster.
        # - On capability 9.x when force_w8a16: cutlass kernels do not support W8A16.
        if (major == 8 or (major == 9 and force_w8a16)) and os.getenv(
            "USE_CUTLASS_W8A8", "0"
        ) != "1":
            # NOTE: Capability 8.9 is supported by cutlass kernels, but FP8-Marlin
            #       gives better decoding throughput on L4 and L40.
            from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

            return GPTQMarlinFP8Linear

    # On other systems let Torch decide if the hardware supports FP8.
    return Fp8Linear


def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
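    """
    Convert an FP8 tensor from e4m3fn to the e4m3fnuz variant used on ROCm,
    adjusting the weight scale (and optional input scale) accordingly.
    """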
    assert weight.dtype == torch.float8_e4m3fn
    # The bit pattern 10000000 (-128) represents zero in e4m3fn
    # but NaN in e4m3fnuz, so we set it to 0 here.
    # https://onnx.ai/onnx/technical/float8.html
    weight_as_int8 = weight.view(torch.int8)
    ROCM_FP8_NAN_AS_INT = -128
    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)

    # For the same bit representation, the e4m3fnuz value is half of the
    # e4m3fn value, so we double the scaling factor to get the same
    # dequantized value.
    # https://onnx.ai/onnx/technical/float8.html
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale


def fp8_quantize(
    weight: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    scale_upper_bound: Optional[torch.Tensor] = None,
    qdtype: torch.dtype = torch.float8_e4m3fn,
    scalar: bool = False,
):
    """
    Quantize a tensor to FP8 and return the quantized tensor together with the
    reciprocal of the scale, so that a tensor can be unscaled by multiplying it
    with the returned scale. If a scale is given through the `scale` argument,
    it must also be a reciprocal (so that scales from an FP8 checkpoint can be
    used without modification).
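
    Illustrative round trip (a sketch; assumes a CUDA device and a float16 input):

        w = torch.randn(128, 256, dtype=torch.float16, device="cuda")
        qweight, inv_scale = fp8_quantize(w, scalar=True)
        w_dequant = qweight.to(torch.float16) * inv_scale.to(torch.float16)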
    """
    if marlin_kernels is not None:
        shape = weight.shape
        qweight, scale = marlin_kernels.scaled_fp8_quant(
            weight.reshape(-1, shape[-1]),
            dtype=qdtype,
            scale=scale,
            scale_ub=scale_upper_bound,
            # TODO: don't do this when we have to use the Torch kernel.
            use_per_token_if_dynamic=not scalar,
        )

        return qweight.reshape(shape), scale

    finfo = torch.finfo(qdtype)

    if scale is None:
        # Calculate the scale as dtype max divided by absmax
        scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
        # Scale and clamp the tensor to bring it into the representable
        # range of the float8 data type (the default cast is unsaturating).
        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
        scale = scale.float().reciprocal()
    else:
        # Use reciprocal to avoid more expensive division.
        qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)

    # Return both the float8 data and the inverse scale (as float),
    # as both are required as inputs to torch._scaled_mm.
    qweight = qweight.to(qdtype)

    if SYSTEM == "rocm":
        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)

    return qweight, scale


class HybridFP8UnquantLoader(WeightsLoader):
    """Weight loader that loads FP8 and unquantized Torch tensors."""

    def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
        self.activation_scale_ub = activation_scale_ub
        self.to_fp8 = to_fp8

    def get_weights(self, weights: "Weights", prefix: str):
        w = weights.get_tensor(f"{prefix}.weight")

        if w.dtype == torch.float8_e4m3fn:
            # FP8 branch
            scale = (
                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                .reshape(-1)
                .expand(w.shape[0])
            )

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
                input_scale = weights.get_tensor(
                    f"{prefix}.input_scale", to_dtype=False
                ).reshape(-1)

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        w = weights.get_packed_sharded(
            f"{prefix}.weight", dim=0, block_sizes=block_sizes
        )

        if w.dtype == torch.float8_e4m3fn:
            # FP8 branch
            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
            if scale.numel() > 1:
                scale = weights.get_packed_sharded(
                    f"{prefix}.weight_scale",
                    dim=0,
                    block_sizes=block_sizes,
                    to_dtype=False,
                )
            scale = scale.reshape(-1).expand(w.shape[0])

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
                input_scale = weights.get_tensor(
                    f"{prefix}.input_scale", to_dtype=False
                )
                if input_scale.numel() > 1:
                    input_scale = weights.get_packed_sharded(
                        f"{prefix}.input_scale",
                        dim=0,
                        block_sizes=block_sizes,
                        to_dtype=False,
                    )
                input_scale = input_scale.reshape(-1).max()

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        # FIXME: force `to_device=False`, as FP8 weights do not yet support torch.cat on device.
        w = [
            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
        ]
        shapes = [x.shape for x in w]

        # Concat then send to the device
        w = torch.cat(w, dim=dim).to(weights.device)

        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
            scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
                for p, shape in zip(prefixes, shapes)
            ]
            scale = torch.cat(scale, dim=0).reshape(-1)

            input_scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
                for p, shape in zip(prefixes, shapes)
                if weights.has_tensor(f"{p}.input_scale")
            ]
            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
            input_scale = (
                torch.cat(input_scale, dim=0).reshape(-1).max()
                if len(input_scale) != 0
                else None
            )

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_weights_row(self, weights: "Weights", prefix: str):
        w = weights.get_sharded(f"{prefix}.weight", dim=1)
        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
            scale = (
                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                .reshape(-1)
                .expand(w.shape[0])
            )
            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
                input_scale = weights.get_tensor(
                    f"{prefix}.input_scale", to_dtype=False
                ).reshape(-1)

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)


@dataclass
class Fp8Weight(Weight):
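    """A weight that is either already in FP8 or will be quantized to FP8
    when the linear layer is built."""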
    weight: torch.Tensor
    dtype: torch.dtype
    weight_scale: Optional[torch.Tensor] = None
    input_scale: Optional[torch.Tensor] = None
    activation_scale_ub: Optional[float] = None
    force_w8a16: bool = False

    def get_linear(self, bias: torch.Tensor):
        if self.weight_scale is None:
            return get_fp8_linear(force_w8a16=self.force_w8a16).from_unquant(
                self.weight, bias, self.dtype
            )
        # This is not checked by the FP8 kernels, but they require contiguous
        # memory. The scale can be non-contiguous when we e.g. expand it from a scalar.
        self.weight_scale = self.weight_scale.contiguous()
        return get_fp8_linear(force_w8a16=self.force_w8a16).from_fp8(
            weight=self.weight,
            scale=self.weight_scale,
            dtype=self.dtype,
            bias=bias,
            input_scale=self.input_scale,
            scale_upper_bound=self.activation_scale_ub,
        )


class Fp8Linear(torch.nn.Module):
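    """FP8 (W8A8) linear layer that uses the cutlass kernels when available
    and falls back to `torch._scaled_mm` otherwise."""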
    _device_identity_cache = {}

    def __init__(
        self,
        qweight: torch.Tensor,
        scale: torch.Tensor,
        dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None,
        input_scale: Optional[torch.Tensor] = None,
        scale_upper_bound: Optional[float] = None,
    ) -> None:
        super().__init__()
        if CUTLASS_FP8_AVAILABLE:
            log_once(logger.info, "Using cutlass w8a8 kernels")
        if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
            qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                weight=qweight, weight_scale=scale
            )

        self.dtype = dtype
        self.qweight = qweight
        self.scale = scale.float()
        self.input_scale = input_scale.float() if input_scale is not None else None

        if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
            self.scale_upper_bound = torch.tensor(
                scale_upper_bound, dtype=torch.float32, device=qweight.device
            )
        else:
            self.scale_upper_bound = scale_upper_bound

        self.bias = bias

    @classmethod
    def from_unquant(cls, weight, bias, dtype):
        qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE)
        return cls(
            qweight=qweight,
            scale=scale,
            dtype=dtype,
            bias=bias,
            input_scale=None,
            scale_upper_bound=None,
        )

    @classmethod
    def from_fp8(
        cls,
        weight: torch.Tensor,
        scale: torch.Tensor,
        dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> "Fp8Linear":
        input_scale = kwargs.get("input_scale", None)
        scale_upper_bound = kwargs.get("scale_upper_bound", None)

        return cls(
            qweight=weight,
            scale=scale,
            input_scale=input_scale,
            scale_upper_bound=scale_upper_bound,
            bias=bias,
            dtype=dtype,
        )

    @classmethod
    def get_shared_device_identity(cls, device):
        # Input scaling factors are no longer optional in _scaled_mm starting
        # from PyTorch 2.5, so allocate a dummy identity tensor to pass as the scale.
        if device not in cls._device_identity_cache:
            cls._device_identity_cache[device] = torch.ones(1, device=device)
        return cls._device_identity_cache[device]

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if CUTLASS_FP8_AVAILABLE:
            # cutlass FP8 supports per-token scales, so get non-scalar scales.
            qinput, scale = fp8_quantize(
                input, scale_upper_bound=self.scale_upper_bound, scalar=False
            )
            return marlin_kernels.cutlass_scaled_mm(
                qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
            )

        qinput, scale = fp8_quantize(
            input,
            self.input_scale,
            scale_upper_bound=self.scale_upper_bound,
            scalar=True,
        )

        per_tensor_weights = self.scale.numel() == 1
        per_tensor_activations = scale.numel() == 1

        if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations):
            output = torch._scaled_mm(
                qinput,
                self.qweight.t(),
                out_dtype=self.dtype,
                scale_a=scale,
                scale_b=self.scale,
                bias=self.bias,
            )

            if isinstance(output, tuple) and len(output) == 2:
                output = output[0]
        else:
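            # Per-token and/or per-channel scales on ROCm: run the GEMM with
            # identity scales and apply the real scales to the output afterwards.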
            device_identity = None
            if SYSTEM == "rocm":
                device_identity = self.get_shared_device_identity(self.qweight.device)

            output = torch._scaled_mm(
                qinput,
                self.qweight.t(),
                scale_a=device_identity,
                scale_b=device_identity,
                out_dtype=torch.float32,
            )
            if isinstance(output, tuple) and len(output) == 2:
                output = output[0]

            output = output * scale * self.scale.t()
            if self.bias is not None:
                output = output + self.bias

            output = output.to(dtype=self.dtype)

        return output


def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
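    """
    Load a scale that is either a scalar or a per-row tensor and expand it
    to one value per row of the (sharded) weight.
    """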
    scale = weights.get_tensor(prefix, to_dtype=False)
    if scale.numel() > 1:
        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
    return scale.reshape(-1).expand(shape[0])