Mirror of
https://github.com/huggingface/text-generation-inference.git
(synced 2025-04-19 22:02:06 +00:00)
but should work on more configurations (no need for 2 GPUs, less RAM usage). # What does this PR do? Reworking the quantization script so it's still universal (not llama specific) but should work on more configurations (no need for 2 GPUs, less RAM usage). Still need to investigate the potential differences in quantization results. <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). 
- [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
983 lines, 30 KiB, Python
import argparse
|
|
import time
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import math
|
|
import json
|
|
import os
|
|
|
|
from texttable import Texttable
|
|
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
|
|
import transformers
|
|
from huggingface_hub import HfApi
|
|
import numpy as np
|
|
import torch
|
|
from accelerate import init_empty_weights
|
|
from text_generation_server.utils import initialize_torch_distributed, Weights
|
|
from text_generation_server.utils.hub import weight_files
|
|
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
|
from loguru import logger
|
|
from typing import Optional
|
|
|
|
# Device on which calibration activations are buffered and GPTQ runs.
DEV: torch.device = torch.device("cuda:0")
|
|
|
|
|
|
class Quantizer(nn.Module):
    """Uniform affine (or ternary) fake-quantizer used by GPTQ.

    State lives in three buffers:

    * ``maxq``  -- largest integer code (``2**bits - 1``), or ``-1`` in
      ternary ("trits") mode;
    * ``scale`` -- step between consecutive codes;
    * ``zero``  -- integer code that maps back to the real value 0.

    ``find_params`` fits ``scale``/``zero`` to a tensor and ``quantize``
    returns the quantize-then-dequantize image of its input.
    """

    def __init__(self, shape=1):
        super().__init__()
        self.register_buffer("maxq", torch.tensor(0))
        self.register_buffer("scale", torch.zeros(shape))
        self.register_buffer("zero", torch.zeros(shape))

    def configure(
        self,
        bits,
        perchannel=False,
        sym=True,
        mse=False,
        norm=2.4,
        grid=100,
        maxshrink=0.8,
        trits=False,
    ):
        """Set the quantization grid and search options, and reset ``scale``.

        Args:
            bits: code width; ``maxq`` becomes ``2**bits - 1``.
            perchannel: fit one (scale, zero) pair per channel instead of one
                for the whole tensor.
            sym: symmetric grid (zero fixed at the middle code).
            mse: grid-search a shrink factor minimising the ``norm``-power error.
            norm, grid, maxshrink: parameters of that search.
            trits: ternary mode, flagged by ``maxq == -1``.
        """
        full_range = 2**bits - 1
        self.maxq = torch.tensor(-1) if trits else torch.tensor(full_range)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        # A zero scale marks the quantizer as "not ready" until find_params runs.
        self.scale = torch.zeros_like(self.scale)

    def _quantize(self, x, scale, zero, maxq):
        """Round ``x`` onto the grid given by ``scale``/``zero`` and map it back."""
        if maxq >= 0:
            codes = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
            return scale * (codes - zero)
        # Ternary mode: values snap to ``scale`` (above scale/2), ``zero``
        # (below zero/2), or 0.
        return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero

    def find_params(self, x, weight=False):
        """Fit ``scale`` and ``zero`` to ``x``.

        With ``perchannel``, one pair is fitted per row of ``x.flatten(1)``
        when ``weight`` is true. For activations, it is fitted per channel
        (dim 1) of a 4-D tensor, or per last-dim feature of a 2-D/3-D
        tensor. Without ``perchannel``, one pair is fitted to the whole
        tensor and repeated. Results are reshaped to broadcast against
        ``x``.
        """
        device = x.device
        self.maxq = self.maxq.to(device)
        orig_shape = x.shape
        ndim = len(orig_shape)

        # Bring x into a (channels, values) layout.
        if not self.perchannel:
            x = x.flatten().unsqueeze(0)
        elif weight:
            x = x.flatten(1)
        elif ndim == 4:
            x = x.permute([1, 0, 2, 3]).flatten(1)
        elif ndim == 3:
            x = x.reshape((-1, orig_shape[-1])).t()
        elif ndim == 2:
            x = x.t()

        # The range always includes 0 so that 0 is exactly representable.
        zeros = torch.zeros(x.shape[0], device=device)
        xmin = torch.minimum(x.min(1)[0], zeros)
        xmax = torch.maximum(x.max(1)[0], zeros)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            negative = xmin < 0
            if torch.any(negative):
                xmin[negative] = -xmax[negative]
        # All-zero channels get a dummy [-1, 1] range to avoid a zero scale.
        degenerate = (xmin == 0) & (xmax == 0)
        xmin[degenerate] = -1
        xmax[degenerate] = +1

        if self.maxq < 0:
            self.scale = xmax
            self.zero = xmin
        else:
            self.scale = (xmax - xmin) / self.maxq
            if self.sym:
                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
            else:
                self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            best = torch.full([x.shape[0]], float("inf"), device=device)
            for step in range(int(self.maxshrink * self.grid)):
                shrink = 1 - step / self.grid
                xmin1 = shrink * xmin
                xmax1 = shrink * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = self.zero if self.sym else torch.round(-xmin1 / scale1)
                approx = self._quantize(
                    x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq
                )
                err = (approx - x).abs().pow(self.norm).sum(1)
                improved = err < best
                if torch.any(improved):
                    best[improved] = err[improved]
                    self.scale[improved] = scale1[improved]
                    self.zero[improved] = zero1[improved]

        if not self.perchannel:
            if weight:
                repeats = orig_shape[0]
            elif ndim == 3:
                repeats = orig_shape[2]
            else:
                repeats = orig_shape[1]
            self.scale = self.scale.repeat(repeats)
            self.zero = self.zero.repeat(repeats)

        if weight:
            target = [-1] + [1] * (ndim - 1)
            self.scale = self.scale.reshape(target)
            self.zero = self.zero.reshape(target)
        elif ndim == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        elif ndim == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        elif ndim == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        """Fake-quantize ``x`` when parameters are fitted; otherwise return it unchanged."""
        return self._quantize(x, self.scale, self.zero, self.maxq) if self.ready() else x

    def enabled(self):
        """Tensor flag: true when a positive code range is configured."""
        return self.maxq > 0

    def ready(self):
        """Tensor flag: true once every scale entry is non-zero."""
        return torch.all(self.scale != 0)
|
|
|
|
|
|
class GPTQ:
    """Accumulate input statistics for one layer and quantize it with GPTQ.

    Usage: feed calibration inputs through ``add_batch`` (typically from a
    forward hook), then call ``fasterquant`` which quantizes the weight
    column by column while compensating the error with the inverse Hessian.
    """

    def __init__(self, layer, observe=False):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            # Conv1D stores weights transposed relative to nn.Linear.
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        self.quantizer = Quantizer()
        self.observe = observe
        # Last observed input/output, only kept when ``observe`` is set;
        # initialised here so ``print_loss`` works even before ``add_batch``.
        self.inp1 = None
        self.out1 = None

    def add_batch(self, inp, out):
        """Fold a batch of layer inputs into the running Hessian estimate.

        Keeps ``H = (2 / nsamples) * sum(x @ x.T)`` over all samples seen, by
        rescaling the previous estimate and adding the scaled new batch.
        """
        # Hessian H = 2 X XT + λ I
        if self.observe:
            self.inp1 = inp
            self.out1 = out
        else:
            self.inp1 = None
            self.out1 = None

        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(
            self.layer, transformers.Conv1D
        ):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride,
            )
            inp = unfold(inp)
            inp = inp.permute([1, 0, 2])
            inp = inp.flatten(1)
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())

    @staticmethod
    def _snr_error(y_pred, y_real):
        """Return the mean noise-to-signal power ratio between two tensors.

        Each leading row is flattened; per row the ratio
        ``sum((pred - real) ** 2) / (sum(real ** 2) + 1e-7)`` is computed and
        the rows are averaged. 0 means identical tensors.

        Raises:
            ValueError: if the shapes differ.
        """
        y_pred = y_pred.float()
        y_real = y_real.float()
        if y_pred.shape != y_real.shape:
            raise ValueError(
                "Can not compute SNR error for tensors with different shapes "
                f"({y_pred.shape} and {y_real.shape})"
            )
        if y_pred.ndim == 1:
            y_pred = y_pred.unsqueeze(0)
            y_real = y_real.unsqueeze(0)
        y_pred = y_pred.flatten(start_dim=1)
        y_real = y_real.flatten(start_dim=1)
        noise_power = (y_pred - y_real).pow(2).sum(dim=-1)
        signal_power = y_real.pow(2).sum(dim=-1)
        return torch.mean(noise_power / (signal_power + 1e-7))

    def print_loss(self, name, q_weight, weight_error, timecost):
        """Write ``q_weight`` into the layer and print one table row of stats.

        Note that this *assigns* the quantized weight to ``self.layer``, so
        later forward passes see the quantized layer.
        """
        table = Texttable()
        length = 28
        name = (
            (name + " " * (length - len(name)))
            if len(name) <= length
            else name[:length]
        )

        table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"])

        # assign weight
        self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(
            self.layer.weight.data.dtype
        )

        if self.inp1 is not None:
            # quantize input to int8
            quantizer = Quantizer()
            quantizer.configure(8, perchannel=False, sym=True, mse=False)
            quantizer.find_params(self.inp1)
            q_in = quantizer.quantize(self.inp1).type(torch.float16)
            q_out = self.layer(q_in)

            # get kinds of SNR (previously called an undefined torch_snr_error)
            q_SNR = self._snr_error(q_out, self.out1).item()
            fp_SNR = self._snr_error(self.layer(self.inp1), self.out1).item()
        else:
            q_SNR = "-"
            fp_SNR = "-"

        table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])
        print(table.draw().split("\n")[-2])

    def fasterquant(
        self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name=""
    ):
        """Quantize the layer weight with GPTQ.

        Args:
            blocksize: columns processed per lazy-update block.
            percdamp: dampening added to the Hessian diagonal, as a fraction
                of its mean.
            groupsize: columns per quantization group, ``-1`` for one group.
            act_order: process columns in decreasing Hessian-diagonal order.
            name: label used in the printed stats row.

        Returns:
            ``(scale, zero, g_idx, error)``: per-group scales and zeros
            concatenated along dim 1, the group index of every column, and
            the summed reconstruction loss.
        """
        self.layer.to(self.dev)

        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

        H = self.H
        if not self.observe:
            del self.H
        # Columns never activated carry no information: pin them to zero.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        if act_order:
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]

        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        try:
            H = torch.linalg.cholesky(H, upper=True)
        except Exception:
            # Addition because Falcon fails on h_to_4h
            H = torch.linalg.cholesky(
                H + 1e-5 * torch.eye(H.shape[0]).to(H.device), upper=True
            )
        Hinv = H

        g_idx = []
        scale = []
        zero = []
        now_idx = 1

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if groupsize != -1:
                    if (i1 + i) % groupsize == 0:
                        self.quantizer.find_params(
                            W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
                        )

                    if ((i1 + i) // groupsize) - now_idx == -1:
                        scale.append(self.quantizer.scale)
                        zero.append(self.quantizer.zero)
                        now_idx += 1

                q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d**2

                # Propagate this column's error onto the remaining block columns.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            Losses[:, i1:i2] = Losses1 / 2

            # Lazy batch update of all columns after this block.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        if self.dev.type == "cuda":
            torch.cuda.synchronize()
        error = torch.sum(Losses).item()

        groupsize = groupsize if groupsize != -1 else self.columns
        g_idx = [i // groupsize for i in range(self.columns)]
        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
        if act_order:
            invperm = torch.argsort(perm)
            Q = Q[:, invperm]
            g_idx = g_idx[invperm]

        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()

        self.print_loss(
            name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)
        )

        if scale == []:
            scale.append(self.quantizer.scale)
            zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)
        return scale, zero, g_idx, error

    def free(self):
        """Drop cached tensors and release cached CUDA memory."""
        self.inp1 = None
        self.out1 = None
        self.H = None
        self.Losses = None
        self.Trace = None
        torch.cuda.empty_cache()
|
|
|
|
|
|
def get_wikitext2(nsamples, seed, seqlen, model_id):
    """Build a WikiText-2 calibration set.

    Returns:
        ``(trainloader, testenc)``:
        - ``trainloader`` is a list of ``nsamples`` ``(inp, tar)`` pairs. Each
          ``inp`` is a random ``seqlen``-token window of the tokenized train
          split, and ``tar`` is a copy with every position but the last set
          to ``-100``.
        - ``testenc`` is the tokenizer output for the whole test split.
    """
    import random

    from datasets import load_dataset
    from transformers import AutoTokenizer

    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    # Same fallback as the other loaders: some models ship only a fast tokenizer.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc
|
|
|
|
|
|
def get_ptb(nsamples, seed, seqlen, model_id):
    """Build a Penn Treebank calibration set (sentences joined with blank lines).

    Returns:
        ``(trainloader, testenc)``:
        - ``trainloader`` is a list of ``nsamples`` ``(inp, tar)`` pairs. Each
          ``inp`` is a random ``seqlen``-token window of the train split, and
          ``tar`` masks every position but the last with ``-100``.
        - ``testenc`` is the tokenizer output for the validation split.
    """
    import random

    from datasets import load_dataset
    from transformers import AutoTokenizer

    traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
    valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation")

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    except Exception:
        # Fall back to the fast tokenizer when no slow one is available.
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt")
    testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt")

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc
|
|
|
|
|
|
def get_c4(nsamples, seed, seqlen, model_id):
    """Build a C4 calibration set from the first train/validation shards.

    Training windows are ``seqlen`` tokens long and come from random
    documents, seeded with ``seed``. The validation set is 256 such windows,
    drawn with a fixed seed of 0 and concatenated along dim 1.

    Returns:
        ``(trainloader, valenc)``:
        - ``trainloader`` is a list of ``(inp, tar)`` pairs; ``tar`` masks
          every position but the last with ``-100``.
        - ``valenc`` wraps the concatenated validation ids in an object
          exposing ``input_ids``.
    """
    import random

    from datasets import load_dataset
    from transformers import AutoTokenizer

    traindata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
        split="train",
        use_auth_token=False,
    )
    valdata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
        split="validation",
        use_auth_token=False,
    )

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
            # Strictly longer than seqlen: with exactly seqlen tokens the
            # randint below would get an empty range (0, -1) and raise.
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    random.seed(0)
    valenc = []
    for _ in range(256):
        while True:
            i = random.randint(0, len(valdata) - 1)
            tmp = tokenizer(valdata[i]["text"], return_tensors="pt")
            if tmp.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        valenc.append(tmp.input_ids[:, i:j])
    valenc = torch.hstack(valenc)

    class TokenizerWrapper:
        def __init__(self, input_ids):
            self.input_ids = input_ids

    valenc = TokenizerWrapper(valenc)

    return trainloader, valenc
|
|
|
|
|
|
def get_ptb_new(nsamples, seed, seqlen, model_id):
    """Build a Penn Treebank calibration set (sentences joined with spaces).

    Unlike ``get_ptb``, this variant evaluates on the test split.

    Returns:
        ``(trainloader, testenc)``:
        - ``trainloader`` is a list of ``nsamples`` ``(inp, tar)`` pairs of
          random ``seqlen``-token windows; ``tar`` masks every position but
          the last with ``-100``.
        - ``testenc`` is the tokenizer output for the test split.
    """
    import random

    from datasets import load_dataset
    from transformers import AutoTokenizer

    traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
    testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt")
    testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt")

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc
|
|
|
|
|
|
def get_c4_new(nsamples, seed, seqlen, model_id):
    """Build a C4 calibration set; validation uses the first 1100 documents.

    Returns:
        ``(trainloader, valenc)``:
        - ``trainloader`` is a list of ``nsamples`` ``(inp, tar)`` pairs of
          random ``seqlen``-token windows; ``tar`` masks every position but
          the last with ``-100``.
        - ``valenc`` wraps the first ``256 * seqlen`` validation token ids
          in an object exposing ``input_ids``.
    """
    import random

    from datasets import load_dataset
    from transformers import AutoTokenizer

    traindata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
        split="train",
    )
    valdata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
        split="validation",
    )

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
            # Strictly longer than seqlen so the randint range below is non-empty.
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt")
    valenc = valenc.input_ids[:, : (256 * seqlen)]

    class TokenizerWrapper:
        def __init__(self, input_ids):
            self.input_ids = input_ids

    valenc = TokenizerWrapper(valenc)

    return trainloader, valenc
|
|
|
|
|
|
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=""):
    """Return ``(trainloader, evalenc)`` for the calibration dataset ``name``.

    ``name`` is matched by substring: ``"wikitext2"``, ``"ptb"`` or ``"c4"``,
    with ``"new"`` selecting the alternative ptb/c4 builders.

    Raises:
        ValueError: if ``name`` matches none of the known datasets.
    """
    if "wikitext2" in name:
        return get_wikitext2(nsamples, seed, seqlen, model_id)
    if "ptb" in name:
        if "new" in name:
            return get_ptb_new(nsamples, seed, seqlen, model_id)
        return get_ptb(nsamples, seed, seqlen, model_id)
    if "c4" in name:
        if "new" in name:
            return get_c4_new(nsamples, seed, seqlen, model_id)
        return get_c4(nsamples, seed, seqlen, model_id)
    # Returning None here would only fail later, at tuple unpacking in the caller.
    raise ValueError(
        f"Unknown calibration dataset {name!r}; "
        "expected one of: wikitext2, ptb, ptb-new, c4, c4-new"
    )
|
|
|
|
|
|
def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""):
    """Collect sub-modules of ``module`` that are instances of ``layers``.

    Returns a dict keyed by dotted path (relative to the first call). The
    search stops descending at each match. Modules whose path contains
    ``"lm_head"`` are never returned, so the output head stays unquantized.
    """
    # isinstance (not an exact type check): Falcon's linear layers subclass nn.Linear.
    if "lm_head" not in name and isinstance(module, layers):
        return {name: module}
    found = {}
    for child_name, child in module.named_children():
        child_path = f"{name}.{child_name}" if name != "" else child_name
        found.update(find_layers(child, layers=layers, name=child_path))
    return found
|
|
|
|
|
|
@torch.no_grad()
def sequential(
    model,
    dataloader,
    dev,
    nsamples,
    bits,
    groupsize,
    *,
    hooks,
    percdamp=0.01,
    sym: bool = False,
    act_order: bool = False,
):
    """Quantize every decoder layer of ``model`` in turn with GPTQ.

    The inputs of the first decoder layer are captured for each calibration
    batch, together with the keyword arguments the model passes to it. The
    layers are then processed one at a time:
    1. Load the layer's weights.
    2. Accumulate Hessians for its linear sub-layers over all samples.
    3. Quantize those sub-layers.
    4. Recompute the layer outputs, which become the next layer's inputs.

    The weight-streaming ``hooks`` are removed after the capture pass.
    ``model.config.use_cache`` is disabled during the run and always restored.

    Returns:
        dict mapping ``"<prefix>.<i>.<sublayer>"`` to
        ``(quantizer, scale, zero, g_idx, bits, groupsize)``.
    """
    print("Starting ...")

    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        # Llama-style models expose model.model.layers; Falcon/GPT-style use transformer.h.
        if hasattr(model, "model") and hasattr(model.model, "layers"):
            layers = model.model.layers
            prefix = "model.layers"
        else:
            layers = model.transformer.h
            prefix = "transformer.h"

        dtype = next(iter(model.parameters())).dtype
        inps = torch.zeros(
            (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
        )

        cache = {"i": 0}
        extra = {}

        class _StopForward(Exception):
            """Aborts the model forward once layer 0's input has been recorded."""

        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, inp, **kwargs):
                inps[cache["i"]] = inp
                cache["i"] += 1
                extra.update(kwargs.copy())
                # A dedicated exception (not ValueError) so genuine model
                # errors are not silently swallowed below.
                raise _StopForward

        layers[0] = Catcher(layers[0])
        try:
            for batch in dataloader:
                try:
                    model(batch[0].to(dev))
                except _StopForward:
                    pass
        finally:
            layers[0] = layers[0].module

        torch.cuda.empty_cache()
        for hook in hooks:
            hook.remove()

        outs = torch.zeros_like(inps)

        extra = {
            k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items()
        }

        print("Ready.")

        quantizers = {}
        for i in range(len(layers)):
            print(f"Quantizing layer {i+1}/{len(layers)}..")
            print("+------------------+--------------+------------+-----------+-------+")
            print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
            print("+==================+==============+============+===========+=======+")

            layer = layers[i]
            layer.load()
            full = find_layers(layer)
            # All sub-layers are quantized together as a single group.
            groups = [list(full.keys())]

            for names in groups:
                subset = {n: full[n] for n in names}
                gptq = {}
                for name in subset:
                    gptq[name] = GPTQ(subset[name])
                    gptq[name].quantizer.configure(
                        bits, perchannel=True, sym=sym, mse=False
                    )

                def add_batch(name):
                    def tmp(_, inp, out):
                        gptq[name].add_batch(inp[0].data, out.data)

                    return tmp

                handles = []
                for name in subset:
                    handles.append(subset[name].register_forward_hook(add_batch(name)))
                # Forward pass only to feed the Hessian hooks.
                for j in range(nsamples):
                    outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
                for h in handles:
                    h.remove()

                for name in subset:
                    scale, zero, g_idx, error = gptq[name].fasterquant(
                        percdamp=percdamp,
                        groupsize=groupsize,
                        act_order=act_order,
                        name=name,
                    )
                    quantizers[f"{prefix}.{i}.{name}"] = (
                        gptq[name].quantizer.cpu(),
                        scale.cpu(),
                        zero.cpu(),
                        g_idx.cpu(),
                        bits,
                        groupsize,
                    )

                    gptq[name].free()

            # Recompute outputs with the quantized weights for the next layer.
            for j in range(nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]

            layer.unload()
            del layer
            del gptq
            torch.cuda.empty_cache()

            inps, outs = outs, inps
            print("+------------------+--------------+------------+-----------+-------+")
            print("\n")

        return quantizers
    finally:
        model.config.use_cache = use_cache
|
|
|
|
|
|
def make_quant_linear(module, names, bits, groupsize, name=""):
    """Recursively replace the sub-modules listed in ``names`` with ``QuantLinear``.

    Args:
        module: module to walk.
        names: container of dotted paths, relative to the top-level call,
            to replace (``pack`` passes the ``quantizers`` dict).
        bits, groupsize: forwarded to ``QuantLinear.new``.
        name: dotted path of ``module``; used during recursion.

    Replaced modules must expose ``in_features``, ``out_features`` and
    ``bias``.
    """
    if isinstance(module, QuantLinear):
        return
    # Targets are registered child modules (paths come from find_layers), so
    # walk those instead of calling getattr on every entry of dir(module).
    # The dir()-based walk also evaluated arbitrary properties.
    # Snapshot and sort: replacement mutates the child registry, and the
    # alphabetical order matches the previous dir()-based traversal.
    children = sorted(module.named_children(), key=lambda item: item[0])
    for attr, child in children:
        path = name + "." + attr if name != "" else attr
        if path in names:
            delattr(module, attr)
            setattr(
                module,
                attr,
                QuantLinear.new(
                    bits,
                    groupsize,
                    child.in_features,
                    child.out_features,
                    child.bias is not None,
                ),
            )
        else:
            make_quant_linear(child, names, bits, groupsize, path)
|
|
|
|
|
|
# TODO: perform packing on GPU
def pack(model, quantizers, bits, groupsize):
    """Swap the quantized layers of ``model`` for packed ``QuantLinear`` modules.

    ``quantizers`` maps dotted layer paths to the tuples built by
    ``sequential``: ``(quantizer, scale, zero, g_idx, bits, groupsize)``.
    Each original layer's current weight is packed using ``scale``, ``zero``
    and ``g_idx``. ``quantizers`` is left unmodified. Returns ``model``.
    """
    layers = find_layers(model)
    layers = {n: layers[n] for n in quantizers}
    make_quant_linear(model, quantizers, bits, groupsize)
    qlayers = find_layers(model, (QuantLinear,))
    print("Packing ...")
    for name in qlayers:
        print(name)
        # Unpack into locals: writing back into quantizers[name] would corrupt
        # the caller's mapping and break a second call.
        _quantizer, scale, zero, g_idx, _, _ = quantizers[name]
        qlayers[name].pack(layers[name], scale, zero, g_idx)
    print("Done.")
    return model
|
|
|
|
|
|
def setdeepattr(module, full_name, tensor):
    """Set the attribute at dotted path ``full_name`` under ``module`` to ``tensor``.

    Every path component but the last is resolved with ``getattr``; the last
    one is assigned with ``setattr`` on the resulting owner.
    """
    *parents, leaf = full_name.split(".")
    owner = module
    for part in parents:
        owner = getattr(owner, part)
    setattr(owner, leaf, tensor)
|
|
|
|
|
|
def getdeepattr(module, full_name):
    """Return the object reached by following dotted path ``full_name`` from ``module``."""
    target = module
    remaining = full_name
    while True:
        head, sep, remaining = remaining.partition(".")
        target = getattr(target, head)
        if not sep:
            return target
|
|
|
|
|
|
def load_weights_pre_hook(module_name, weights, recursive=False):
    """Build a hook that materialises a module's tensors on ``cuda:0``.

    The returned ``inner(module, args)`` has the forward-pre-hook signature.
    It gathers the module's named parameters and buffers. Unless
    ``recursive`` is true, only names containing exactly one dot are kept,
    i.e. tensors owned by direct children; the module's own tensors are
    skipped.

    Each selected tensor on the ``meta`` device is read from ``weights``
    under ``f"{module_name}.{local_name}"``, or just ``local_name`` when
    ``module_name`` is empty. Any other tensor is moved to ``cuda:0``.

    The result is re-registered as an ``nn.Parameter`` with
    ``requires_grad=False``. Nothing here is trained, and integer/bool
    buffers cannot be wrapped in a gradient-requiring Parameter.
    """

    def inner(module, args):
        print(f"Pre hook {module_name}")
        local_params = {}
        for k, v in module.named_parameters():
            if not recursive and k.count(".") != 1:
                continue
            local_params[k] = v
        for k, v in module.named_buffers():
            if not recursive and k.count(".") != 1:
                continue
            local_params[k] = v

        for local_param in local_params:
            current_tensor = getdeepattr(module, local_param)
            if current_tensor.device == torch.device("meta"):
                if module_name:
                    tensor_name = f"{module_name}.{local_param}"
                else:
                    tensor_name = local_param
                tensor = weights.get_tensor(tensor_name)
            else:
                tensor = current_tensor.to(device=torch.device("cuda:0"))
            setdeepattr(module, local_param, nn.Parameter(tensor, requires_grad=False))

    return inner
|
|
|
|
|
|
def load_weights_post_hook(module_name, weights, recursive=False):
    """Build a hook that moves a module's tensors back to the CPU.

    The returned ``inner(module, args, output)`` has the forward-hook
    signature and returns ``output`` unchanged. The same tensors are selected
    as in ``load_weights_pre_hook``: with ``recursive`` false, only names with
    exactly one dot are kept. Each is re-registered as an ``nn.Parameter`` on
    ``cpu`` with ``requires_grad=False``, which also accepts integer/bool
    buffers.

    ``weights`` is unused; presumably it is kept for symmetry with the
    pre-hook factory.
    """

    def inner(module, args, output):
        print(f"Post hook {module_name}")
        local_params = {}
        for k, v in module.named_parameters():
            if not recursive and k.count(".") != 1:
                continue
            local_params[k] = v
        for k, v in module.named_buffers():
            if not recursive and k.count(".") != 1:
                continue
            local_params[k] = v
        for local_param in local_params:
            current_tensor = getdeepattr(module, local_param)
            setdeepattr(
                module,
                local_param,
                nn.Parameter(
                    current_tensor.to(device=torch.device("cpu")), requires_grad=False
                ),
            )
        return output

    return inner
|
|
|
|
|
|
def quantize(
    model_id: str,
    bits: int,
    groupsize: int,
    output_dir: str,
    revision: str,
    trust_remote_code: bool,
    upload_to_model_id: Optional[str],
    percdamp: float,
    act_order: bool,
):
    """GPTQ-quantize ``model_id`` and save it as sharded safetensors.

    Steps:
    1. Instantiate the model with empty weights in float16. Weights are
       streamed from the model's ``.safetensors`` files onto ``cuda:0`` by
       per-module hooks and ``load``/``unload`` helpers.
    2. Calibrate on 128 WikiText-2 windows of 2048 tokens and quantize layer
       by layer (``sequential``).
    3. Pack the weights (``pack``).
    4. Write ``output_dir``: safetensors shards of at most 10GB (plus an
       index when sharded), config and tokenizer.
    5. Optionally upload ``output_dir`` to ``upload_to_model_id``.

    Config, tokenizer and weights are all fetched at ``revision``.
    """
    print("loading model")
    config = AutoConfig.from_pretrained(
        model_id,
        revision=revision,
        trust_remote_code=trust_remote_code,
    )

    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(
            config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
        )
    model = model.eval()

    print("LOADED model")
    files = weight_files(model_id, revision, extension=".safetensors")
    process_group, _, _ = initialize_torch_distributed()
    weights = Weights(
        files,
        device=torch.device("cuda:0"),
        dtype=torch.float16,
        process_group=process_group,
        aliases={"embed_tokens.weight": ["lm_head.weight"]},
    )

    def make_load(module, name):
        def _load():
            load_weights_pre_hook(name, weights, recursive=True)(module, None)

        return _load

    def make_unload(module, name):
        def _unload():
            load_weights_post_hook(name, weights, recursive=True)(module, None, None)

        return _unload

    hooks = []
    for name, module in model.named_modules():
        # Explicit full (recursive) load/unload used by `sequential` per layer.
        module.load = make_load(module, name)
        module.unload = make_unload(module, name)
        # Lazy per-forward loading, used for the calibration capture pass.
        hooks.append(
            module.register_forward_pre_hook(load_weights_pre_hook(name, weights))
        )
        hooks.append(
            module.register_forward_hook(load_weights_post_hook(name, weights))
        )
    model.seqlen = 2048

    dataset = "wikitext2"
    nsamples = 128
    # NOTE(review): seed=None makes calibration sampling non-deterministic
    # across runs; consider a fixed seed for reproducible quantization.
    seed = None

    dataloader, testloader = get_loaders(
        dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen
    )

    tick = time.time()
    quantizers = sequential(
        model,
        dataloader,
        DEV,
        nsamples,
        bits,
        groupsize,
        percdamp=percdamp,
        act_order=act_order,
        hooks=hooks,
    )
    print(time.time() - tick)

    pack(model, quantizers, bits, groupsize)
    from safetensors.torch import save_file
    from transformers.modeling_utils import shard_checkpoint

    state_dict = model.state_dict()
    state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
    state_dict["gptq_bits"] = torch.LongTensor([bits])
    state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])

    max_shard_size = "10GB"
    shards, index = shard_checkpoint(
        state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
    )
    os.makedirs(output_dir, exist_ok=True)
    for shard_file, shard in shards.items():
        save_file(
            shard,
            os.path.join(output_dir, shard_file),
            metadata={
                "format": "pt",
                "quantized": "gptq",
                "origin": "text-generation-inference",
            },
        )
    if index is None:
        path_to_weights = os.path.join(output_dir, "model.safetensors")
        logger.info(f"Model weights saved in {path_to_weights}")
    else:
        save_index_file = "model.safetensors.index.json"
        save_index_file = os.path.join(output_dir, save_index_file)
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)
        logger.info(
            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
            f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
            f"index located at {save_index_file}."
        )
    # Reload a fresh config, at the same revision as the weights, to save alongside them.
    config = AutoConfig.from_pretrained(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    config.save_pretrained(output_dir)
    logger.info("Saved config")
    logger.info("Saving tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    tokenizer.save_pretrained(output_dir)
    logger.info("Saved tokenizer")

    if upload_to_model_id:
        api = HfApi()

        api.upload_folder(
            folder_path=output_dir, repo_id=upload_to_model_id, repo_type="model"
        )
|