import time
import torch.nn as nn
import math
import json
import os
import torch
import transformers

from texttable import Texttable
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
from huggingface_hub import HfApi
from accelerate import init_empty_weights
from text_generation_server.utils import initialize_torch_distributed, Weights
from text_generation_server.utils.hub import weight_files
from text_generation_server.utils.gptq.quant_linear import QuantLinear
from loguru import logger
from typing import Optional

# Device passed to `sequential` as `dev` by `quantize`; calibration
# activations and the per-layer GPTQ state live here.
# NOTE(review): other helpers in this file hard-code "cuda:0" / `.cuda()`
# instead of reading this constant.
DEV = torch.device("cuda:0")


class Quantizer(nn.Module):
    """Uniform fake-quantizer (per-tensor or per-channel) used by GPTQ.

    Buffers:
        maxq: largest integer level (2**bits - 1), or -1 in "trits" mode.
        scale, zero: grid parameters produced by `find_params`; an all-nonzero
            `scale` is what `ready()` checks for.
    """

    def __init__(self, shape=1):
        super(Quantizer, self).__init__()
        self.register_buffer("maxq", torch.tensor(0))
        self.register_buffer("scale", torch.zeros(shape))
        self.register_buffer("zero", torch.zeros(shape))

    def configure(
        self,
        bits,
        perchannel=False,
        sym=True,
        mse=False,
        norm=2.4,
        grid=100,
        maxshrink=0.8,
        trits=False,
    ):
        """Set quantization options and reset `scale` (so `ready()` is False).

        Args:
            bits: bit width; integer levels are 0 .. 2**bits - 1.
            perchannel: one (scale, zero) pair per row/channel instead of one
                pair for the whole tensor.
            sym: symmetric range; the zero point is fixed at (maxq + 1) / 2.
            mse: search shrunken ranges for the one with the lowest
                `norm`-power reconstruction error.
            norm: error exponent used by the MSE search.
            grid: shrink step denominator for the MSE search.
            maxshrink: fraction of `grid` steps the MSE search explores.
            trits: ternary mode; sets `maxq` to -1, handled in `_quantize`.
        """
        self.maxq = torch.tensor(2**bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        if trits:
            self.maxq = torch.tensor(-1)
        # A zero scale marks the quantizer as not yet calibrated.
        self.scale = torch.zeros_like(self.scale)

    def _quantize(self, x, scale, zero, maxq):
        """Fake-quantize `x`: round to the integer grid, clamp, dequantize.

        In trits mode (`maxq < 0`) `scale`/`zero` hold the range max/min:
        values above scale/2 map to `scale`, values below zero/2 map to
        `zero`, everything else maps to 0.
        """
        if maxq < 0:
            return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
        q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
        return scale * (q - zero)

    def find_params(self, x, weight=False):
        """Compute `scale` and `zero` from the min/max of `x`.

        Args:
            x: tensor to calibrate on. With `weight=True` rows are output
                channels; otherwise 2-D/3-D/4-D activations are reorganised so
                the channel axis comes first.
            weight: reshape the results to broadcast against a weight of
                `x.shape` (leading -1, trailing singleton dims) and return.
        """
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        # Bring the tensor to 2-D with one row per quantization channel.
        if self.perchannel:
            if weight:
                x = x.flatten(1)
            else:
                if len(shape) == 4:
                    x = x.permute([1, 0, 2, 3])
                    x = x.flatten(1)
                if len(shape) == 3:
                    x = x.reshape((-1, shape[-1])).t()
                if len(shape) == 2:
                    x = x.t()
        else:
            x = x.flatten().unsqueeze(0)

        # Range per row, always including 0.
        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmin < 0
            if torch.any(tmp):
                xmin[tmp] = -xmax[tmp]
        # Avoid a zero-width range (and therefore a zero scale).
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        if self.maxq < 0:
            self.scale = xmax
            self.zero = xmin
        else:
            self.scale = (xmax - xmin) / self.maxq
            if self.sym:
                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
            else:
                self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            # Try progressively shrunken ranges; keep the per-row best.
            best = torch.full([x.shape[0]], float("inf"), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                q = self._quantize(
                    x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq
                )
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                tmp = err < best
                if torch.any(tmp):
                    best[tmp] = err[tmp]
                    self.scale[tmp] = scale1[tmp]
                    self.zero[tmp] = zero1[tmp]
        # Per-tensor mode: broadcast the single pair to every channel.
        if not self.perchannel:
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        if weight:
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        # Reshape so scale/zero broadcast along the activation channel axis.
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        if len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        if len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        """Fake-quantize `x` if calibrated, otherwise return it unchanged."""
        if self.ready():
            return self._quantize(x, self.scale, self.zero, self.maxq)

        return x

    def enabled(self):
        """Return a bool tensor: True when a positive bit range is configured."""
        return self.maxq > 0

    def ready(self):
        """Return a bool tensor: True once every entry of `scale` is nonzero."""
        return torch.all(self.scale != 0)


class GPTQ:
    """GPTQ state for one layer: a Hessian proxy plus the quantization solver.

    Typical use (see `sequential`):
        1. call `add_batch` with the layer's inputs for each calibration batch
           (usually from a forward hook);
        2. call `fasterquant`, which writes the quantized weights back into
           `layer.weight` and returns the grid parameters;
        3. call `free` to release the large buffers.
    """

    def __init__(self, layer, observe=False):
        """
        Args:
            layer: module with a `.weight`; nn.Conv2d weights are flattened
                and transformers.Conv1D weights transposed to (rows, columns).
            observe: keep the latest input/output batch (for SNR reporting in
                `print_loss`) and keep `self.H` after `fasterquant`.
        """
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        self.quantizer = Quantizer()
        self.observe = observe
        # Initialised here so `print_loss` works even if `add_batch` was never
        # called for this layer.
        self.inp1 = None
        self.out1 = None

    def add_batch(self, inp, out):
        """Fold one batch of layer inputs into the Hessian proxy `self.H`.

        `self.H` is kept as a running average equal to
        (2 / nsamples) * sum(X_b @ X_b.T), where `nsamples` grows by
        `inp.shape[0]` per call (after a 2-D input gets a leading dim of 1).

        Args:
            inp: input activations of `self.layer`.
            out: the layer's output; stored only when `observe` is set.
        """
        # Hessian H = 2 X XT + λ I
        if self.observe:
            self.inp1 = inp
            self.out1 = out
        else:
            self.inp1 = None
            self.out1 = None

        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(
            self.layer, transformers.Conv1D
        ):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride,
            )
            inp = unfold(inp)
            inp = inp.permute([1, 0, 2])
            inp = inp.flatten(1)
        # Rescale the existing average before adding the new contribution.
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        # Pre-scaling by sqrt(2 / N) makes inp @ inp.T contribute 2/N * X X^T.
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())

    @staticmethod
    def _snr_error(y_pred, y_real):
        """Return the mean noise-to-signal power ratio of `y_pred` vs `y_real`.

        Each tensor is flattened to (rows, -1) (1-D inputs become one row);
        per row the ratio sum((pred - real)^2) / (sum(real^2) + 1e-7) is
        computed and the rows are averaged. Returns a 0-dim tensor.

        Raises:
            ValueError: if the two tensors have different shapes.
        """
        y_pred = y_pred.float()
        y_real = y_real.float()
        if y_pred.shape != y_real.shape:
            raise ValueError(
                "Can not compute snr loss for tensors with different shape. "
                f"({y_pred.shape} and {y_real.shape})"
            )
        if y_pred.ndim == 1:
            y_pred = y_pred.unsqueeze(0)
            y_real = y_real.unsqueeze(0)
        y_pred = y_pred.flatten(start_dim=1)
        y_real = y_real.flatten(start_dim=1)
        noise_power = (y_pred - y_real).pow(2).sum(dim=-1)
        signal_power = y_real.pow(2).sum(dim=-1)
        return (noise_power / (signal_power + 1e-7)).mean()

    def print_loss(self, name, q_weight, weight_error, timecost):
        """Write `q_weight` into the layer and print one summary table row.

        Side effect: `self.layer.weight.data` is replaced by `q_weight`
        (reshaped and cast to the layer's dtype). When an observed batch is
        available, SNR is reported for the quantized layer on the original
        inputs and on int8-quantized inputs; otherwise "-" is shown.
        """
        table = Texttable()
        length = 28
        name = (
            (name + " " * (length - len(name)))
            if len(name) <= length
            else name[:length]
        )

        table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"])

        # assign weight
        self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(
            self.layer.weight.data.dtype
        )

        if self.inp1 is not None:
            # quantize input to int8
            quantizer = Quantizer()
            quantizer.configure(8, perchannel=False, sym=True, mse=False)
            quantizer.find_params(self.inp1)
            q_in = quantizer.quantize(self.inp1).type(torch.float16)
            q_out = self.layer(q_in)

            # get kinds of SNR
            q_SNR = self._snr_error(q_out, self.out1).item()
            fp_SNR = self._snr_error(self.layer(self.inp1), self.out1).item()
        else:
            q_SNR = "-"
            fp_SNR = "-"

        table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])
        # Row only: the header/border lines are printed by the caller.
        print(table.draw().split("\n")[-2])

    def fasterquant(
        self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name=""
    ):
        """Quantize `self.layer.weight` column by column using the GPTQ update.

        Args:
            blocksize: number of columns processed per lazy-update block.
            percdamp: fraction of the mean Hessian diagonal added as damping.
            groupsize: columns sharing one (scale, zero); -1 means all columns.
            act_order: process columns in descending Hessian-diagonal order.
            name: label used in the printed summary row.

        Returns:
            (scale, zero, g_idx, error): per-group scales and zero points
            concatenated along dim 1, the int32 group index of every column,
            and the summed weighted quantization loss.

        Side effect: the quantized weights are written into the layer by
        `print_loss`.
        """
        self.layer.to(self.dev)

        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

        H = self.H
        if not self.observe:
            del self.H
        # Columns whose input was always zero: neutralise them.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        if act_order:
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]

        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

        # Dampen the diagonal, then take the upper Cholesky factor of H^-1.
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        try:
            H = torch.linalg.cholesky(H, upper=True)
        except Exception:
            # Addition because Falcon fails on h_to_4h
            H = torch.linalg.cholesky(
                H + 1e-5 * torch.eye(H.shape[0], device=H.device), upper=True
            )
        Hinv = H

        g_idx = []
        scale = []
        zero = []
        now_idx = 1

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if groupsize != -1:
                    # New group: recalibrate on the (already updated) columns.
                    if (i1 + i) % groupsize == 0:
                        self.quantizer.find_params(
                            W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
                        )

                    if ((i1 + i) // groupsize) - now_idx == -1:
                        scale.append(self.quantizer.scale)
                        zero.append(self.quantizer.zero)
                        now_idx += 1

                q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d**2

                # Spread this column's error over the remaining block columns.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            Losses[:, i1:i2] = Losses1 / 2

            # Lazy batch update of all columns after this block.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        torch.cuda.synchronize()
        error = torch.sum(Losses).item()

        groupsize = groupsize if groupsize != -1 else self.columns
        g_idx = [i // groupsize for i in range(self.columns)]
        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
        if act_order:
            invperm = torch.argsort(perm)
            Q = Q[:, invperm]
            g_idx = g_idx[invperm]

        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()

        self.print_loss(
            name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)
        )

        if scale == []:
            scale.append(self.quantizer.scale)
            zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)
        return scale, zero, g_idx, error

    def free(self):
        """Drop references to large buffers and empty the CUDA cache."""
        self.inp1 = None
        self.out1 = None
        self.H = None
        self.Losses = None
        self.Trace = None
        torch.cuda.empty_cache()
        torch.cuda.empty_cache()


def get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code):
    """Build WikiText-2 calibration windows and the tokenized test split.

    Args:
        nsamples: number of calibration windows to draw.
        seed: seed passed to `random.seed` (None seeds from the OS/time).
        seqlen: tokens per window.
        model_id: model whose tokenizer is used.
        trust_remote_code: forwarded to `AutoTokenizer.from_pretrained`.

    Returns:
        (trainloader, testenc): `trainloader` is a list of (inp, tar) pairs
        where `inp` is `input_ids[:, i:i + seqlen]` of the joined train split
        and `tar` is a copy with every position but the last set to -100;
        `testenc` is the tokenizer output for the joined test split.
    """
    from datasets import load_dataset
    import random

    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
    except Exception:
        # The slow tokenizer could not be loaded; retry with the fast one.
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )

    trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc


def get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code):
    """Build Penn Treebank calibration windows and the tokenized validation split.

    Sentences are joined with blank lines before tokenization.

    Args:
        nsamples: number of calibration windows to draw.
        seed: seed passed to `random.seed` (None seeds from the OS/time).
        seqlen: tokens per window.
        model_id: model whose tokenizer is used.
        trust_remote_code: forwarded to `AutoTokenizer.from_pretrained`.

    Returns:
        (trainloader, testenc): `trainloader` is a list of (inp, tar) pairs
        (`tar` is `inp` with all but the last position set to -100);
        `testenc` is the tokenizer output for the joined validation split.
    """
    from datasets import load_dataset
    import random

    traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
    valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation")

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
    except Exception:
        # The slow tokenizer could not be loaded; retry with the fast one.
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )

    trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt")
    testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt")

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc


def get_c4(nsamples, seed, seqlen, model_id, trust_remote_code):
    """Build C4 calibration windows and a fixed C4 validation tensor.

    Each calibration window comes from a random document longer than
    `seqlen` tokens. The validation set is 256 such windows drawn with seed 0
    and concatenated along dim 1.

    Args:
        nsamples: number of calibration windows to draw.
        seed: seed passed to `random.seed` for the calibration windows.
        seqlen: tokens per window.
        model_id: model whose tokenizer is used.
        trust_remote_code: forwarded to `AutoTokenizer.from_pretrained`.

    Returns:
        (trainloader, valenc): `trainloader` is a list of (inp, tar) pairs
        (`tar` is `inp` with all but the last position set to -100);
        `valenc` is an object whose `.input_ids` holds the concatenated
        validation windows.
    """
    from datasets import load_dataset
    import random

    traindata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
        split="train",
        use_auth_token=False,
    )
    valdata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
        split="validation",
        use_auth_token=False,
    )

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
    except Exception:
        # The slow tokenizer could not be loaded; retry with the fast one.
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
            # Strictly longer than seqlen: the start index below is drawn from
            # [0, len - seqlen - 1], which is empty when len == seqlen.
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    # Fixed seed so the validation set is the same regardless of `seed`.
    random.seed(0)
    valenc = []
    for _ in range(256):
        while True:
            i = random.randint(0, len(valdata) - 1)
            tmp = tokenizer(valdata[i]["text"], return_tensors="pt")
            if tmp.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        valenc.append(tmp.input_ids[:, i:j])
    valenc = torch.hstack(valenc)

    class TokenizerWrapper:
        """Expose a bare tensor through the `.input_ids` attribute."""

        def __init__(self, input_ids):
            self.input_ids = input_ids

    valenc = TokenizerWrapper(valenc)

    return trainloader, valenc


def get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code):
    """Build Penn Treebank calibration windows and the tokenized test split.

    Unlike `get_ptb`, sentences are joined with single spaces and the test
    (not validation) split is returned.

    Args:
        nsamples: number of calibration windows to draw.
        seed: seed passed to `random.seed` (None seeds from the OS/time).
        seqlen: tokens per window.
        model_id: model whose tokenizer is used.
        trust_remote_code: forwarded to `AutoTokenizer.from_pretrained`.

    Returns:
        (trainloader, testenc): `trainloader` is a list of (inp, tar) pairs
        (`tar` is `inp` with all but the last position set to -100);
        `testenc` is the tokenizer output for the joined test split.
    """
    from datasets import load_dataset
    import random

    traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
    testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
    except Exception:
        # The slow tokenizer could not be loaded; retry with the fast one.
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )

    trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt")
    testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt")

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc


def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):
    """Build C4 calibration windows and a contiguous C4 validation tensor.

    Calibration windows come from random documents longer than `seqlen`
    tokens. The validation set is the first 1100 validation documents joined
    with spaces, tokenized, and truncated to `256 * seqlen` tokens.

    Args:
        nsamples: number of calibration windows to draw.
        seed: seed passed to `random.seed`.
        seqlen: tokens per window.
        model_id: model whose tokenizer is used.
        trust_remote_code: forwarded to `AutoTokenizer.from_pretrained`.

    Returns:
        (trainloader, valenc): `trainloader` is a list of (inp, tar) pairs
        (`tar` is `inp` with all but the last position set to -100);
        `valenc` is an object whose `.input_ids` holds the validation tokens.
    """
    from datasets import load_dataset
    import random

    traindata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
        split="train",
    )
    valdata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
        split="validation",
    )

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
    except Exception:
        # The slow tokenizer could not be loaded; retry with the fast one.
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
            # Strictly longer than seqlen: the start index below is drawn from
            # [0, len - seqlen - 1], which is empty when len == seqlen.
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt")
    valenc = valenc.input_ids[:, : (256 * seqlen)]

    class TokenizerWrapper:
        """Expose a bare tensor through the `.input_ids` attribute."""

        def __init__(self, input_ids):
            self.input_ids = input_ids

    valenc = TokenizerWrapper(valenc)

    return trainloader, valenc


def get_loaders(
    name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False
):
    """Dispatch to a calibration-data loader by dataset name.

    Matching is by substring: "wikitext2", "ptb" (with "new" selecting
    `get_ptb_new`) or "c4" (with "new" selecting `get_c4_new`).

    Returns:
        The (trainloader, eval_data) pair produced by the selected loader.

    Raises:
        ValueError: if `name` matches none of the supported datasets.
    """
    if "wikitext2" in name:
        return get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code)
    if "ptb" in name:
        if "new" in name:
            return get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code)
        return get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code)
    if "c4" in name:
        if "new" in name:
            return get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code)
        return get_c4(nsamples, seed, seqlen, model_id, trust_remote_code)
    raise ValueError(
        f"Unknown calibration dataset {name!r}; expected a name containing "
        "'wikitext2', 'ptb' or 'c4'."
    )


def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""):
    """Map qualified names to every submodule that is an instance of `layers`.

    Search stops descending at a match. Modules whose qualified name contains
    "lm_head" are never returned, so the output head is left unquantized.
    `isinstance` is used because some models (e.g. Falcon) subclass nn.Linear.
    """
    if "lm_head" not in name and isinstance(module, layers):
        return {name: module}
    found = {}
    for child_name, child in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        found.update(find_layers(child, layers=layers, name=qualified))
    return found


@torch.no_grad()
def sequential(
    model,
    dataloader,
    dev,
    nsamples,
    bits,
    groupsize,
    *,
    hooks,
    percdamp=0.01,
    sym: bool = False,
    act_order: bool = False,
):
    """Quantize the model's decoder layers one at a time with GPTQ.

    The first decoder layer's inputs are captured for every calibration
    batch. Then, for each layer in turn:
        1. accumulate GPTQ statistics with forward hooks;
        2. quantize every Linear/Conv2d sublayer (the quantized weights are
           written back into the layer);
        3. re-run the layer to produce the inputs for the next layer.

    Args:
        model: causal LM exposing `model.model.layers` or `model.transformer.h`
            and a `seqlen` attribute; each layer must have `load`/`unload`
            methods (attached in `quantize`).
        dataloader: iterable of (input_ids, target) pairs.
        dev: device holding the captured activations.
        nsamples: number of calibration batches (size of the activation buffer).
        bits: quantization bit width.
        groupsize: GPTQ group size (-1 for per-row).
        hooks: hook handles removed once the first-layer inputs are captured.
        percdamp: Hessian dampening fraction for `GPTQ.fasterquant`.
        sym: symmetric quantization.
        act_order: quantize columns in descending Hessian-diagonal order.

    Returns:
        dict mapping "<prefix>.<layer index>.<sublayer name>" to
        (quantizer, scale, zero, g_idx, bits, groupsize), tensors on CPU.
    """
    print("Starting ...")

    use_cache = model.config.use_cache
    model.config.use_cache = False
    # LLaMA-style models expose model.model.layers; Falcon/GPT-style expose
    # model.transformer.h.
    try:
        layers = model.model.layers
        prefix = "model.layers"
    except Exception:
        layers = model.transformer.h
        prefix = "transformer.h"

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )

    cache = {"i": 0}
    extra = {}

    class Catcher(nn.Module):
        """Stand-in for layer 0: records its input and kwargs, then aborts."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            extra.update(kwargs.copy())
            # Abort the rest of the forward pass; caught below.
            raise ValueError

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].cuda())
        except ValueError:
            pass
    layers[0] = layers[0].module

    # layers[0] = layers[0].cpu()
    # model.model.embed_tokens = model.model.embed_tokens.cpu()
    # model.model.norm = model.model.norm.cpu()
    torch.cuda.empty_cache()
    # From here on, layer weights are loaded/unloaded explicitly via
    # layer.load()/layer.unload() instead of per-forward hooks.
    for hook in hooks:
        hook.remove()

    outs = torch.zeros_like(inps)

    # Extra forward kwargs captured by Catcher (the last batch's values win).
    extra = {
        k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items()
    }

    print("Ready.")

    quantizers = {}
    for i in range(len(layers)):
        print(f"Quantizing layer {i+1}/{len(layers)}..")
        print("+------------------+--------------+------------+-----------+-------+")
        print("|       name       | weight_error | fp_inp_SNR | q_inp_SNR | time  |")
        print("+==================+==============+============+===========+=======+")

        layer = layers[i]
        layer.load()
        full = find_layers(layer)
        # A single group containing every sublayer (note: shadows this
        # function's name locally).
        sequential = [list(full.keys())]

        for names in sequential:
            subset = {n: full[n] for n in names}
            gptq = {}
            for name in subset:
                gptq[name] = GPTQ(subset[name])
                gptq[name].quantizer.configure(
                    bits, perchannel=True, sym=sym, mse=False
                )
                pass

            def add_batch(name):
                # Factory binds `name` per hook (avoids late-binding closures).
                def tmp(_, inp, out):
                    gptq[name].add_batch(inp[0].data, out.data)

                return tmp

            handles = []
            for name in subset:
                handles.append(subset[name].register_forward_hook(add_batch(name)))
            # Statistics pass: outputs here are discarded (overwritten below).
            for j in range(nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
            for h in handles:
                h.remove()

            for name in subset:
                scale, zero, g_idx, error = gptq[name].fasterquant(
                    percdamp=percdamp,
                    groupsize=groupsize,
                    act_order=act_order,
                    name=name,
                )
                quantizers[f"{prefix}.{i}.{name}"] = (
                    gptq[name].quantizer.cpu(),
                    scale.cpu(),
                    zero.cpu(),
                    g_idx.cpu(),
                    bits,
                    groupsize,
                )

                gptq[name].free()

        # Re-run with the now-quantized weights to feed the next layer.
        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]

        layer.unload()
        del layer
        del gptq
        torch.cuda.empty_cache()

        # This layer's outputs become the next layer's inputs.
        inps, outs = outs, inps
        print("+------------------+--------------+------------+-----------+-------+")
        print("\n")

    model.config.use_cache = use_cache

    return quantizers


def make_quant_linear(module, names, bits, groupsize, name=""):
    """Recursively replace the submodules named in `names` with QuantLinear.

    Each replaced module must expose `in_features`, `out_features` and
    `bias`. The new QuantLinear is created with `QuantLinear.new` and holds
    no packed weights yet (see `pack`). Recursion stops at QuantLinear
    modules.

    Args:
        module: module to modify in place.
        names: container of fully qualified names to replace (membership
            tested with `in`).
        bits: quantization bit width.
        groupsize: GPTQ group size.
        name: qualified name of `module` (used during recursion).
    """
    if isinstance(module, QuantLinear):
        return
    for attr in dir(module):
        name1 = name + "." + attr if name != "" else attr
        if name1 not in names:
            continue
        # Resolve only the attributes being replaced; evaluating every entry
        # of dir() would run arbitrary properties on the module.
        tmp = getattr(module, attr)
        delattr(module, attr)
        setattr(
            module,
            attr,
            QuantLinear.new(
                bits,
                groupsize,
                tmp.in_features,
                tmp.out_features,
                tmp.bias is not None,
            ),
        )
    for name1, child in module.named_children():
        make_quant_linear(
            child, names, bits, groupsize, name + "." + name1 if name != "" else name1
        )


# TODO: perform packing on GPU
def pack(model, quantizers, bits, groupsize):
    """Swap quantized layers for QuantLinear modules and pack their weights.

    Args:
        model: model whose layers were quantized by `sequential`.
        quantizers: mapping from qualified layer name to the tuple
            (quantizer, scale, zero, g_idx, bits, groupsize) returned by
            `sequential`. The mapping is not modified.
        bits: quantization bit width.
        groupsize: GPTQ group size.

    Returns:
        `model`, modified in place.
    """
    # Keep references to the original float layers before they are replaced.
    layers = find_layers(model)
    layers = {n: layers[n] for n in quantizers}
    make_quant_linear(model, quantizers, bits, groupsize)
    qlayers = find_layers(model, (QuantLinear,))
    print("Packing ...")
    for name in qlayers:
        print(name)
        # Unpack into locals so the caller's `quantizers` mapping is untouched.
        _, scale, zero, g_idx, _, _ = quantizers[name]
        qlayers[name].pack(layers[name], scale, zero, g_idx)
    print("Done.")
    return model


def setdeepattr(module, full_name, tensor):
    """Assign `tensor` to the attribute at dotted path `full_name` under `module`."""
    *parents, leaf = full_name.split(".")
    owner = module
    for part in parents:
        owner = getattr(owner, part)
    setattr(owner, leaf, tensor)


def getdeepattr(module, full_name):
    """Return the object reached by following dotted path `full_name` from `module`."""
    target = module
    remaining = full_name
    while True:
        head, sep, remaining = remaining.partition(".")
        target = getattr(target, head)
        if not sep:
            return target


def load_weights_pre_hook(module_name, weights, recursive=False):
    """Build a forward-pre-hook that materialises a module's tensors on cuda:0.

    For each selected parameter/buffer: a tensor still on the meta device is
    fetched with `weights.get_tensor("<module_name>.<local name>")` and set
    as an nn.Parameter; any other tensor is moved to cuda:0 (re-wrapped as
    nn.Parameter when it requires grad).

    Args:
        module_name: qualified name of the module (prefix for weight lookup).
        weights: object providing `get_tensor(name)`.
        recursive: select every tensor of the subtree. When False, only names
            containing exactly one "." are selected (tensors of immediate
            children such as "child.weight"; the module's own tensors are
            skipped). NOTE(review): confirm this filter is intended.

    Returns:
        A callable `(module, args) -> None`.
    """

    def inner(module, args):
        print(f"Pre hook {module_name}")
        selected = {}
        for named_tensors in (module.named_parameters(), module.named_buffers()):
            for key, value in named_tensors:
                if recursive or key.count(".") == 1:
                    selected[key] = value

        for key in selected:
            existing = getdeepattr(module, key)
            is_meta = existing.device == torch.device("meta")
            if not is_meta:
                on_gpu = existing.to(device=torch.device("cuda:0"))
                replacement = nn.Parameter(on_gpu) if existing.requires_grad else on_gpu
                setdeepattr(module, key, replacement)
                continue
            full_name = f"{module_name}.{key}" if module_name else key
            setdeepattr(module, key, nn.Parameter(weights.get_tensor(full_name)))

    return inner


def load_weights_post_hook(module_name, weights, recursive=False):
    """Build a forward hook that moves a module's tensors back to CPU.

    Every selected parameter/buffer is replaced by an nn.Parameter on CPU.
    The hook returns the forward output unchanged.

    Args:
        module_name: qualified name of the module (used only in the log line).
        weights: unused by the hook; kept for symmetry with
            `load_weights_pre_hook`.
        recursive: select every tensor of the subtree. When False, only names
            containing exactly one "." are selected (tensors of immediate
            children). NOTE(review): confirm this filter is intended.

    Returns:
        A callable `(module, args, output) -> output`.
    """

    def inner(module, args, output):
        print(f"Post hook {module_name}")
        selected = {}
        for named_tensors in (module.named_parameters(), module.named_buffers()):
            for key, value in named_tensors:
                if recursive or key.count(".") == 1:
                    selected[key] = value
        for key in selected:
            on_cpu = getdeepattr(module, key).to(device=torch.device("cpu"))
            setdeepattr(module, key, nn.Parameter(on_cpu))
        return output

    return inner


def quantize(
    model_id: str,
    bits: int,
    groupsize: int,
    output_dir: str,
    revision: str,
    trust_remote_code: bool,
    upload_to_model_id: Optional[str],
    percdamp: float,
    act_order: bool,
):
    """GPTQ-quantize `model_id` and save the result to `output_dir`.

    The model is instantiated under `init_empty_weights`. Every submodule
    gets `load`/`unload` methods plus forward pre/post hooks that pull its
    tensors from the safetensors weight files onto cuda:0 and move them back
    to CPU. Calibration uses 128 WikiText-2 windows of 2048 tokens. The
    packed model is saved as (possibly sharded) safetensors, together with
    the config and tokenizer.

    Args:
        model_id: Hub id or local path of the model.
        bits: quantization bit width.
        groupsize: GPTQ group size (-1 for one group per row).
        output_dir: destination directory (created if missing).
        revision: model revision used for the weight files, config and
            tokenizer. NOTE: the calibration tokenizer loaded by
            `get_loaders` does not receive it.
        trust_remote_code: forwarded to the transformers loaders.
        upload_to_model_id: if set, `output_dir` is uploaded to this Hub repo.
        percdamp: Hessian dampening fraction for GPTQ.
        act_order: quantize columns in descending Hessian-diagonal order.
    """
    print("loading model")
    config = AutoConfig.from_pretrained(
        model_id,
        revision=revision,
        trust_remote_code=trust_remote_code,
    )

    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(
            config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
        )
    model = model.eval()

    print("LOADED model")
    files = weight_files(model_id, revision, extension=".safetensors")
    process_group, _, _ = initialize_torch_distributed()
    weights = Weights(
        files,
        device=torch.device("cuda:0"),
        dtype=torch.float16,
        process_group=process_group,
        aliases={"embed_tokens.weight": ["lm_head.weight"]},
    )

    # Factories bind `module`/`name` per iteration (no late-binding issues).
    def load(module, name):
        def _load():
            load_weights_pre_hook(name, weights, recursive=True)(module, None)

        return _load

    def unload(module, name):
        def _unload():
            load_weights_post_hook(name, weights, recursive=True)(
                module, None, None
            )

        return _unload

    hooks = []
    for name, module in model.named_modules():
        module.load = load(module, name)
        module.unload = unload(module, name)
        hooks.append(
            module.register_forward_pre_hook(load_weights_pre_hook(name, weights))
        )
        hooks.append(
            module.register_forward_hook(load_weights_post_hook(name, weights))
        )
    model.seqlen = 2048

    dataset = "wikitext2"
    nsamples = 128
    # NOTE: seed=None makes random.seed pick a fresh seed, so the calibration
    # windows differ between runs.
    seed = None

    dataloader, testloader = get_loaders(
        dataset,
        nsamples=nsamples,
        seed=seed,
        model_id=model_id,
        seqlen=model.seqlen,
        trust_remote_code=trust_remote_code,
    )

    tick = time.time()
    quantizers = sequential(
        model,
        dataloader,
        DEV,
        nsamples,
        bits,
        groupsize,
        percdamp=percdamp,
        act_order=act_order,
        hooks=hooks,
    )
    print(time.time() - tick)

    pack(model, quantizers, bits, groupsize)
    from safetensors.torch import save_file
    from transformers.modeling_utils import shard_checkpoint

    state_dict = model.state_dict()
    state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
    state_dict["gptq_bits"] = torch.LongTensor([bits])
    state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])

    max_shard_size = "10GB"
    shards, index = shard_checkpoint(
        state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
    )
    os.makedirs(output_dir, exist_ok=True)
    for shard_file, shard in shards.items():
        save_file(
            shard,
            os.path.join(output_dir, shard_file),
            metadata={
                "format": "pt",
                "quantized": "gptq",
                "origin": "text-generation-inference",
            },
        )
    if index is None:
        path_to_weights = os.path.join(output_dir, "model.safetensors")
        logger.info(f"Model weights saved in {path_to_weights}")
    else:
        save_index_file = "model.safetensors.index.json"
        save_index_file = os.path.join(output_dir, save_index_file)
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)
        logger.info(
            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
            f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
            f"index located at {save_index_file}."
        )
    # Reload a fresh config (same revision as the weights) for saving.
    config = AutoConfig.from_pretrained(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    config.save_pretrained(output_dir)
    logger.info("Saved config")
    logger.info("Saving tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    tokenizer.save_pretrained(output_dir)
    logger.info("Saved tokenizer")

    if upload_to_model_id:
        api = HfApi()

        api.upload_folder(
            folder_path=output_dir, repo_id=upload_to_model_id, repo_type="model"
        )