import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from safetensors import safe_open, SafetensorError
import torch
from loguru import logger
from huggingface_hub import hf_hub_download
import json
from text_generation_server.utils.log import log_once


@dataclass
class _GPTQParams:
    """Quantization parameters resolved by `Weights._get_gptq_params`."""

    # Value of the `gptq_bits` tensor or of the config's bits entry; also used
    # as the packing width (`32 // bits`) when a `g_idx` tensor is rebuilt.
    bits: int
    # Optional `checkpoint_format` entry of the quantization config.
    checkpoint_format: Optional[str]
    # Value of the `gptq_groupsize` tensor or of the config's group size entry.
    groupsize: int
    # Config `desc_act` flag; the exllama kernels are disabled when it is true.
    desc_act: bool
    # Config `quant_method`; "gptq" and "awq" are the values branched on in
    # this file.
    quant_method: str
    # Symmetric-quantization flag, forwarded to `repack_gptq_for_marlin`.
    sym: bool


class Weights:
    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        aliases: Optional[Dict[str, List[str]]] = None,
        prefix: Optional[str] = None,
    ):
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self.prefix = prefix
        self._handles = {}

    def _get_handle(self, filename):
        if filename not in self._handles:
            f = safe_open(filename, framework="pytorch")
            self._handles[filename] = f

        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> (str, str):
        names = [tensor_name]
        if self.prefix is not None:
            prefixed = f"{self.prefix}.{tensor_name}"
            names.append(prefixed)
        for name in names:
            filename = self.routing.get(name, None)
            if filename is not None:
                return str(filename), name

            aliases = self.aliases.get(name, [])
            for alias in aliases:
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
        raise RuntimeError(f"weight {tensor_name} does not exist")

    def _get_slice(self, tensor_name: str):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        return slice_

    def get_shape(self, tensor_name: str):
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str, to_device=True):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32. Exl2 uses int16
        # as well.
        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_partial_sharded(self, tensor_name: str, dim: int):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        size = slice_.get_shape()[dim]
        block_size = (size + world_size - 1) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32. exl2 uses int16.
        if tensor.dtype not in (torch.int16, torch.int32):
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)

    def get_packed_sharded(
        self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]]
    ) -> torch.Tensor:
        """
        Get a shard from a tensor that packs multiple tensors.

        When a tensor packs multiple tensors (such as QKV or an up
        projection + gate projection), sharding with `get_sharded` is not
        safe since it would not split the packed tensors across shards.

        This method shards a tensor, such that the packed tensors are
        split across shards.

        The columns are split in equally sized blocks when blocks is an `int`, or
        in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
        convenient for e.g. splitting QKV without knowing the storage details of
        quantized weights.
        """
        slice_ = self._get_slice(tensor_name)
        total_size = slice_.get_shape()[dim]
        block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)

        world_size = self.process_group.size()
        rank = self.process_group.rank()

        tensors = []
        block_offset = 0
        for block_size in block_sizes:
            assert (
                block_size % world_size == 0
            ), f"Prepacked tensor cannot be sharded across {world_size} shards"
            shard_block_size = block_size // world_size
            start = rank * shard_block_size
            stop = (rank + 1) * shard_block_size
            if dim == 0:
                tensor = slice_[block_offset + start : block_offset + stop]
            elif dim == 1:
                tensor = slice_[:, block_offset + start : block_offset + stop]
            else:
                raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
            tensors.append(tensor)
            block_offset += block_size
        tensor = torch.cat(tensors, dim=dim)
        tensor = tensor.to(device=self.device)

        # Avoid casting quantizer dtypes.
        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)

        return tensor

    def get_weights_col_packed_qkv(
        self,
        prefix: str,
        quantize: str,
        num_heads: int,
        num_key_value_heads: int,
    ):
        return self.get_weights_col_packed(
            prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads]
        )

    def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
        return self.get_weights_col_packed(prefix, quantize, 2)

    def get_weights_col_packed(
        self, prefix: str, quantize: str, block_sizes: Union[int, List[int]]
    ):
        """
        Load this rank's shard of a packed weight stored under `prefix`.

        Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
        already alternating Q,K,V within the main tensor.

        The columns are split in equally sized blocks when blocks is an `int`, or
        in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
        convenient for e.g. splitting QKV without knowing the storage details of
        quantized weights.

        The return value depends on `quantize`:

        - "gptq" / "awq": a `GPTQWeight` (never using exllama here);
        - "marlin": a `GPTQMarlin24Weight`, the result of
          `repack_gptq_for_marlin`, or a `MarlinWeight`, depending on the
          checkpoint format and quantization method recorded on `self`;
        - anything else: the `{prefix}.weight` tensor, packed-sharded along
          dimension 0.

        Raises:
            RuntimeError: if the quantized `qweight` tensor cannot be loaded.
        """
        if quantize in ["gptq", "awq"]:
            # Imported lazily, only when a quantized weight is requested.
            from text_generation_server.layers.gptq import GPTQWeight

            try:
                qweight = self.get_packed_sharded(
                    f"{prefix}.qweight", dim=1, block_sizes=block_sizes
                )
            except RuntimeError:
                # A missing `qweight` is reported as "model not quantized".
                raise RuntimeError(
                    f"Cannot load `{quantize}` weight, make sure the model is already quantized."
                )

            gptq_params = self._get_gptq_params()

            qzeros = self.get_packed_sharded(
                f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
            )
            scales = self.get_packed_sharded(
                f"{prefix}.scales", dim=1, block_sizes=block_sizes
            )
            scales = scales.to(dtype=self.dtype)

            if quantize == "gptq" and gptq_params.quant_method == "gptq":
                # `g_idx` is loaded whole, not sharded.
                g_idx = self.get_tensor(f"{prefix}.g_idx")
            elif quantize == "gptq" and gptq_params.quant_method == "awq":
                # An AWQ checkpoint served as GPTQ: repack the tensors and
                # synthesize `g_idx` as `index // groupsize`.
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
                from text_generation_server.layers.awq.conversion_utils import (
                    fast_awq_to_gptq,
                )

                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                g_idx = (
                    torch.arange(
                        qweight.shape[0] * (32 // gptq_params.bits),
                        device=qweight.device,
                    )
                    // gptq_params.groupsize
                ).to(dtype=torch.int32)
            else:
                g_idx = None

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=gptq_params.bits,
                groupsize=gptq_params.groupsize,
                use_exllama=False,
            )
        elif quantize == "marlin":
            from text_generation_server.layers.marlin import (
                GPTQMarlin24Weight,
                MarlinWeight,
                repack_gptq_for_marlin,
            )

            # Both attributes are set by `_set_gptq_params`; the defaults
            # apply when they were never set on this instance.
            quant_method = getattr(self, "quant_method", "marlin")
            is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
            if is_marlin_24:
                # "marlin_24" checkpoint: tensors are named B_24 / B_meta / s.
                B = self.get_packed_sharded(
                    f"{prefix}.B_24", dim=1, block_sizes=block_sizes
                )
                B_meta = self.get_packed_sharded(
                    f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
                )
                s = self.get_packed_sharded(
                    f"{prefix}.s", dim=1, block_sizes=block_sizes
                )

                gptq_params = self._get_gptq_params()
                weight = GPTQMarlin24Weight(
                    B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
                )
            elif quant_method == "gptq":
                # GPTQ checkpoint repacked for the Marlin kernels.
                gptq_params = self._get_gptq_params()
                try:
                    qweight = self.get_packed_sharded(
                        f"{prefix}.qweight", dim=1, block_sizes=block_sizes
                    )
                except RuntimeError:
                    raise RuntimeError(
                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
                    )

                scales = self.get_packed_sharded(
                    f"{prefix}.scales", dim=1, block_sizes=block_sizes
                )
                g_idx = self.get_tensor(f"{prefix}.g_idx")
                weight = repack_gptq_for_marlin(
                    qweight=qweight,
                    scales=scales,
                    g_idx=g_idx,
                    bits=gptq_params.bits,
                    desc_act=gptq_params.desc_act,
                    groupsize=gptq_params.groupsize,
                    sym=gptq_params.sym,
                    sharded_infeatures=False,
                )
            else:
                # Checkpoint with tensors named B / s.
                B = self.get_packed_sharded(
                    f"{prefix}.B", dim=1, block_sizes=block_sizes
                )
                s = self.get_packed_sharded(
                    f"{prefix}.s", dim=1, block_sizes=block_sizes
                )
                weight = MarlinWeight(B=B, s=s)
        else:
            # Any other `quantize` value: plain `weight` tensor, packed along
            # dimension 0.
            weight = self.get_packed_sharded(
                f"{prefix}.weight", dim=0, block_sizes=block_sizes
            )
        return weight

    def get_weights_col(self, prefix: str, quantize: str):
        if quantize == "exl2":
            from text_generation_server.layers.exl2 import Exl2Weight

            try:
                q_weight = self.get_tensor(f"{prefix}.q_weight")
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
                )

            q_scale = self.get_tensor(f"{prefix}.q_scale")
            q_invperm = self.get_tensor(f"{prefix}.q_invperm")
            q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
            q_groups = self.get_tensor(f"{prefix}.q_groups")

            return Exl2Weight(
                q_weight=q_weight,
                q_scale=q_scale,
                q_invperm=q_invperm,
                q_scale_max=q_scale_max,
                q_groups=q_groups,
            )

        return self.get_multi_weights_col([prefix], quantize, 0)

    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
        """
        Load and concatenate this rank's shards of several weights.

        Each prefix is sharded separately with `get_sharded` and the shards
        are concatenated. The return value depends on `quantize`:

        - "gptq" / "awq": a `GPTQWeight` built from the concatenated
          `qweight`, `qzeros` and `scales` tensors;
        - "marlin": a `GPTQMarlin24Weight`, the result of
          `repack_gptq_for_marlin`, or a `MarlinWeight`, depending on the
          checkpoint format and quantization method recorded on `self`;
        - anything else: the `{prefix}.weight` tensors, sharded along
          dimension 0 and concatenated along `dim`.

        `dim` is only used for unquantized weights; the quantized branches
        always shard and concatenate along dimension 1.

        Raises:
            ValueError: if `quantize` is "exl2".
            RuntimeError: if the first quantized tensor cannot be loaded.
        """
        if quantize == "exl2":
            raise ValueError("get_multi_weights_col is not supported for exl2")
        elif quantize in ["gptq", "awq"]:
            from text_generation_server.layers.gptq import GPTQWeight

            try:
                qweight = torch.cat(
                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
                )
            except RuntimeError:
                # A missing `qweight` is reported as "model not quantized".
                raise RuntimeError(
                    f"Cannot load `{quantize}` weight, make sure the model is already quantized"
                )

            qzeros = torch.cat(
                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
            )
            scales = torch.cat(
                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
            )

            gptq_params = self._get_gptq_params()

            from text_generation_server.layers.gptq import HAS_EXLLAMA

            # Exllama is only used for 4-bit GPTQ without act-order, and only
            # when `HAS_EXLLAMA` is truthy.
            use_exllama = (
                gptq_params.bits == 4
                and HAS_EXLLAMA
                and quantize == "gptq"
                and not gptq_params.desc_act
            )

            if quantize == "gptq" and gptq_params.quant_method == "gptq":
                # All prefixes must carry the same `g_idx`; the first one is
                # used for the combined weight.
                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                for w2 in w[1:]:
                    torch.testing.assert_close(w2, w[0])
                g_idx = w[0]
            elif quantize == "gptq" and gptq_params.quant_method == "awq":
                # An AWQ checkpoint served as GPTQ: repack the tensors, and
                # synthesize `g_idx` (`index // groupsize`) unless exllama is
                # used.
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
                from text_generation_server.layers.awq.conversion_utils import (
                    fast_awq_to_gptq,
                )

                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                if use_exllama:
                    g_idx = None
                else:
                    g_idx = (
                        torch.arange(
                            qweight.shape[0] * (32 // gptq_params.bits),
                            device=qweight.device,
                        )
                        // gptq_params.groupsize
                    ).to(dtype=torch.int32)
            else:
                g_idx = None

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=gptq_params.bits,
                groupsize=gptq_params.groupsize,
                use_exllama=use_exllama,
            )
        elif quantize == "marlin":
            from text_generation_server.layers.gptq import GPTQWeight
            from text_generation_server.layers.marlin import (
                GPTQMarlin24Weight,
                MarlinWeight,
                repack_gptq_for_marlin,
            )

            # Both attributes are set by `_set_gptq_params`; the defaults
            # apply when they were never set on this instance.
            quant_method = getattr(self, "quant_method", "marlin")
            is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
            if is_marlin_24:
                # "marlin_24" checkpoint: tensors are named B_24 / B_meta / s.
                try:
                    B = torch.cat(
                        [self.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
                    )
                except RuntimeError:
                    raise RuntimeError(
                        f"Cannot load `{quantize}` weight, make sure the model is already quantized"
                    )

                B_meta = torch.cat(
                    [self.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
                )

                s = torch.cat(
                    [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
                )

                gptq_params = self._get_gptq_params()
                weight = GPTQMarlin24Weight(
                    B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
                )
            elif quant_method == "gptq":
                # GPTQ checkpoint repacked for the Marlin kernels.
                gptq_params = self._get_gptq_params()
                try:
                    qweight = torch.cat(
                        [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes],
                        dim=1,
                    )
                except RuntimeError:
                    raise RuntimeError(
                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
                    )

                scales = torch.cat(
                    [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
                )
                # Same consistency check on `g_idx` as in the GPTQ branch.
                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                for w2 in w[1:]:
                    torch.testing.assert_close(w2, w[0])
                g_idx = w[0]

                weight = repack_gptq_for_marlin(
                    qweight=qweight,
                    scales=scales,
                    g_idx=g_idx,
                    bits=gptq_params.bits,
                    desc_act=gptq_params.desc_act,
                    groupsize=gptq_params.groupsize,
                    sym=gptq_params.sym,
                    sharded_infeatures=False,
                )
            else:
                # Checkpoint with tensors named B / s.
                try:
                    B = torch.cat(
                        [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
                    )
                except RuntimeError:
                    raise RuntimeError(
                        f"Cannot load `{quantize}` weight, make sure the model is already quantized"
                    )
                s = torch.cat(
                    [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
                )

                weight = MarlinWeight(B=B, s=s)

        else:
            # Any other `quantize` value: shard each plain weight along
            # dimension 0, then concatenate along the requested `dim`.
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
            weight = torch.cat(w, dim=dim)

        return weight

    def get_tensor_shard(self, var, dim):
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        block_size = var.size()[dim] // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        if dim == 0:
            tensor = var[start:stop]
        elif dim == 1:
            tensor = var[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_multi_weights_row(self, prefix: str, quantize: str):
        """
        Load this rank's shard of a row-parallel weight stored under `prefix`.

        The return value depends on `quantize`:

        - "exl2": an `Exl2Weight` built from whole (unsharded) tensors;
        - "gptq": a `GPTQWeight`, using the exllama kernels when the
          parameters, the `g_idx` layout and the installed kernels allow it;
        - "awq": a `GPTQWeight` without `g_idx` and without exllama;
        - "marlin": a `GPTQMarlin24Weight`, the result of
          `repack_gptq_for_marlin`, or a `MarlinWeight`, depending on the
          checkpoint format and quantization method recorded on `self`;
        - anything else: the `{prefix}.weight` tensor sharded along
          dimension 1.

        Quantized tensors are sharded along dimension 0.

        Raises:
            RuntimeError: if the first quantized tensor cannot be loaded.
        """
        if quantize == "exl2":
            from text_generation_server.layers.exl2 import Exl2Weight

            try:
                q_weight = self.get_tensor(f"{prefix}.q_weight")
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
                )

            q_scale = self.get_tensor(f"{prefix}.q_scale")
            q_invperm = self.get_tensor(f"{prefix}.q_invperm")
            q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
            q_groups = self.get_tensor(f"{prefix}.q_groups")

            return Exl2Weight(
                q_weight=q_weight,
                q_scale=q_scale,
                q_invperm=q_invperm,
                q_scale_max=q_scale_max,
                q_groups=q_groups,
            )

        elif quantize == "gptq":
            # Start from "use exllama" and switch it off whenever one of the
            # conditions below rules it out.
            use_exllama = True
            gptq_params = self._get_gptq_params()

            if gptq_params.bits != 4:
                use_exllama = False

            if gptq_params.desc_act:
                log_once(logger.warning, "Disabling exllama because desc_act=True")
                use_exllama = False

            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )

            # NOTE(review): `g_idx` stays unbound when `quant_method` is
            # neither "gptq" nor "awq" - confirm no other value can reach
            # this branch.
            if gptq_params.quant_method == "gptq":
                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
            elif gptq_params.quant_method == "awq":
                g_idx = None

            if self.process_group.size() > 1:
                if g_idx is not None:
                    # Exllama stays enabled only if this shard's `g_idx`
                    # equals `i // groupsize` for i in range(len(g_idx)), or
                    # is all zeros.
                    if (
                        not torch.equal(
                            g_idx.cpu(),
                            torch.tensor(
                                [
                                    i // gptq_params.groupsize
                                    for i in range(g_idx.shape[0])
                                ],
                                dtype=torch.int32,
                            ),
                        )
                        and not (g_idx == 0).all()
                    ):
                        # Exllama implementation does not support row tensor parallelism with act-order, as
                        # it would require to reorder input activations that are split unto several GPUs
                        use_exllama = False

            from text_generation_server.layers.gptq import (
                HAS_EXLLAMA,
                CAN_EXLLAMA,
                GPTQWeight,
            )

            if use_exllama:
                if not HAS_EXLLAMA:
                    # Kernels are not available: warn if they could have been
                    # used, then fall back.
                    if CAN_EXLLAMA:
                        log_once(
                            logger.warning,
                            "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
                        )
                    use_exllama = False
                else:
                    log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")

            # `qzeros` and `scales` are sharded only with exllama and an
            # explicit group size; otherwise every rank loads them whole.
            if use_exllama and gptq_params.groupsize != -1:
                qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                scales = self.get_sharded(f"{prefix}.scales", dim=0)
            else:
                qzeros = self.get_tensor(f"{prefix}.qzeros")
                scales = self.get_tensor(f"{prefix}.scales")

            if use_exllama and g_idx is not None:
                # Rebase the group indices so that this shard starts at 0.
                g_idx = g_idx - g_idx[0]

            if gptq_params.quant_method == "awq":
                # An AWQ checkpoint served as GPTQ: repack the tensors, and
                # synthesize `g_idx` (`index // groupsize`) unless exllama is
                # used.
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
                from text_generation_server.layers.awq.conversion_utils import (
                    fast_awq_to_gptq,
                )

                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                if use_exllama:
                    g_idx = None
                else:
                    g_idx = (
                        torch.arange(
                            qweight.shape[0] * (32 // gptq_params.bits),
                            device=qweight.device,
                        )
                        // gptq_params.groupsize
                    ).to(dtype=torch.int32)

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=gptq_params.bits,
                groupsize=gptq_params.groupsize,
                use_exllama=use_exllama,
            )
        elif quantize == "awq":
            from text_generation_server.layers.gptq import GPTQWeight

            gptq_params = self._get_gptq_params()

            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `awq` weight, make sure the model is already quantized"
                )

            qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
            scales = self.get_sharded(f"{prefix}.scales", dim=0)
            # AWQ weights carry no `g_idx` and never use exllama here.
            g_idx = None
            use_exllama = False

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=gptq_params.bits,
                groupsize=gptq_params.groupsize,
                use_exllama=use_exllama,
            )
        elif quantize == "marlin":
            from text_generation_server.layers.gptq import GPTQWeight
            from text_generation_server.layers.marlin import (
                GPTQMarlin24Weight,
                MarlinWeight,
                repack_gptq_for_marlin,
            )

            # Both attributes are set by `_set_gptq_params`; the defaults
            # apply when they were never set on this instance.
            quant_method = getattr(self, "quant_method", "marlin")
            is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
            if is_marlin_24:
                # "marlin_24" checkpoint: tensors are named B_24 / B_meta / s.
                try:
                    B = self.get_sharded(f"{prefix}.B_24", dim=0)
                except RuntimeError:
                    raise RuntimeError(
                        "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
                    )

                B_meta = self.get_sharded(f"{prefix}.B_meta", dim=0)
                num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
                if num_groups == 1:
                    # The number of groups is 1 when groupsize == -1. share
                    # scales between all shards in this case.
                    s = self.get_tensor(f"{prefix}.s")
                else:
                    s = self.get_sharded(f"{prefix}.s", dim=0)

                gptq_params = self._get_gptq_params()
                weight = GPTQMarlin24Weight(
                    B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
                )
            elif quant_method == "gptq":
                log_once(logger.info, "Converting GPTQ model to Marlin packing format.")
                gptq_params = self._get_gptq_params()

                try:
                    qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
                except RuntimeError:
                    raise RuntimeError(
                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
                    )

                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
                # With act-order or a single group, every rank loads the
                # whole `scales` tensor; otherwise it is sharded too.
                if gptq_params.desc_act or gptq_params.groupsize == -1:
                    scales = self.get_tensor(f"{prefix}.scales")
                else:
                    scales = self.get_sharded(f"{prefix}.scales", dim=0)

                # The input features are sharded whenever there is more than
                # one rank.
                sharded_in_features = self.process_group.size() > 1

                weight = repack_gptq_for_marlin(
                    qweight=qweight,
                    scales=scales,
                    g_idx=g_idx,
                    bits=gptq_params.bits,
                    desc_act=gptq_params.desc_act,
                    groupsize=gptq_params.groupsize,
                    sym=gptq_params.sym,
                    sharded_infeatures=sharded_in_features,
                )
            else:
                # Checkpoint with tensors named B / s.
                try:
                    B = self.get_sharded(f"{prefix}.B", dim=0)
                except RuntimeError:
                    raise RuntimeError(
                        "Cannot load `marlin` weight, make sure the model is already quantized."
                    )

                num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
                if num_groups == 1:
                    # The number of groups is 1 when groupsize == -1. share
                    # scales between all shards in this case.
                    s = self.get_tensor(f"{prefix}.s")
                else:
                    s = self.get_sharded(f"{prefix}.s", dim=0)
                weight = MarlinWeight(B=B, s=s)

        else:
            # Any other `quantize` value: plain weight, sharded along
            # dimension 1.
            weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weight

    def _get_gptq_params(self) -> _GPTQParams:
        try:
            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()
            checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
            desc_act = False
            sym = True
            quant_method = "gptq"
        except (SafetensorError, RuntimeError) as e:
            try:
                bits = self.gptq_bits
                groupsize = self.gptq_groupsize
                checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
                desc_act = getattr(self, "gptq_desc_act", False)
                quant_method = getattr(self, "quant_method", "gptq")
                sym = getattr(self, "sym", True)
            except Exception:
                raise e

        return _GPTQParams(
            bits=bits,
            checkpoint_format=checkpoint_format,
            desc_act=desc_act,
            groupsize=groupsize,
            quant_method=quant_method,
            sym=sym,
        )

    def _set_gptq_params(self, model_id, revision):
        filename = "config.json"
        try:
            if os.path.exists(os.path.join(model_id, filename)):
                filename = os.path.join(model_id, filename)
            else:
                filename = hf_hub_download(
                    model_id, filename=filename, revision=revision
                )
            with open(filename, "r") as f:
                data = json.load(f)
            self.gptq_bits = data["quantization_config"]["bits"]
            self.gptq_groupsize = data["quantization_config"]["group_size"]
            # Order is important here, desc_act is missing on some real models
            self.quant_method = data["quantization_config"]["quant_method"]
            self.gptq_checkpoint_format = data["quantization_config"].get(
                "checkpoint_format"
            )
            self.gptq_sym = data["quantization_config"]["sym"]
            self.gptq_desc_act = data["quantization_config"]["desc_act"]
        except Exception:
            filename = "quantize_config.json"
            try:
                if os.path.exists(os.path.join(model_id, filename)):
                    filename = os.path.join(model_id, filename)
                else:
                    filename = hf_hub_download(
                        model_id, filename=filename, revision=revision
                    )
                with open(filename, "r") as f:
                    data = json.load(f)
                self.gptq_bits = data["bits"]
                self.gptq_groupsize = data["group_size"]
                self.gptq_sym = data["sym"]
                self.gptq_desc_act = data["desc_act"]
                if "version" in data and data["version"] == "GEMM":
                    self.quant_method = "awq"
            except Exception:
                filename = "quant_config.json"
                try:
                    if os.path.exists(os.path.join(model_id, filename)):
                        filename = os.path.join(model_id, filename)
                    else:
                        filename = hf_hub_download(
                            model_id, filename=filename, revision=revision
                        )
                    with open(filename, "r") as f:
                        data = json.load(f)
                    self.gptq_bits = data["w_bit"]
                    self.gptq_groupsize = data["q_group_size"]
                    self.gptq_desc_act = data["desc_act"]
                    if "version" in data and data["version"] == "GEMM":
                        self.quant_method = "awq"
                except Exception:
                    pass


def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
    """
    Convert block count or proportions to block sizes.

    This function accepts

    - The number of blocks (int), in which case the block size is
      total_size//blocks; or
    - A list of block sizes (List[int]).

    In the latter case, if sum(blocks) < total_size, the ratios between
    the block sizes will be preserved. For instance, if blocks is
    [2, 1, 1] and total_size is 1024, the returned block sizes are
    [512, 256, 256].
    """
    if isinstance(blocks, list):
        total_blocks = sum(blocks)
        assert (
            total_size % total_blocks == 0
        ), f"Cannot split {total_size} in proportional blocks: {blocks}"
        part_size = total_size // total_blocks
        return [part_size * block for block in blocks]
    else:
        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
        single_size = total_size // blocks
        return [single_size] * blocks