Nicolas Patry 2023-06-14 00:08:33 +02:00
parent ee1f94e64b
commit e5e552b496


@@ -19,14 +19,23 @@ DEV = torch.device("cuda:0")

class Quantizer(nn.Module):
    def __init__(self, shape=1):
        super(Quantizer, self).__init__()
        self.register_buffer("maxq", torch.tensor(0))
        self.register_buffer("scale", torch.zeros(shape))
        self.register_buffer("zero", torch.zeros(shape))

    def configure(
        self,
        bits,
        perchannel=False,
        sym=True,
        mse=False,
        norm=2.4,
        grid=100,
        maxshrink=0.8,
        trits=False,
    ):
        self.maxq = torch.tensor(2**bits - 1)
        self.perchannel = perchannel
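        # maxq is the top of the integer grid: 2**bits - 1 levels above zero,
        # e.g. 15 for 4-bit weights.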
@@ -88,14 +97,16 @@ class Quantizer(nn.Module):
        self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            best = torch.full([x.shape[0]], float("inf"), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                q = self._quantize(
                    x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq
                )
                q -= x
                q.abs_()
                q.pow_(self.norm)
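                # The mse path grid-searches progressively shrunken clip ranges
                # [p * xmin, p * xmax] and scores each candidate by the
                # elementwise error |q - x| ** norm; the per-row best over the
                # grid is kept (the comparison falls outside this hunk).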
@@ -142,7 +153,6 @@ class Quantizer(nn.Module):


class GPTQ:
    def __init__(self, layer, observe=False):
        self.layer = layer
        self.dev = self.layer.weight.device
@@ -170,12 +180,19 @@ class GPTQ:
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(
            self.layer, transformers.Conv1D
        ):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride,
            )
            inp = unfold(inp)
            inp = inp.permute([1, 0, 2])
            inp = inp.flatten(1)
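            # nn.Unfold turns each conv receptive field into a column, so conv
            # inputs can be treated exactly like the Linear case when
            # accumulating the Hessian statistics.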
@@ -188,12 +205,19 @@ class GPTQ:
    def print_loss(self, name, q_weight, weight_error, timecost):
        table = Texttable()
        length = 30
        name = (
            (name + " " * (length - len(name)))
            if len(name) <= length
            else name[:length]
        )

        table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"])

        # assign weight
        self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(
            self.layer.weight.data.dtype
        )

        if self.inp1 is not None:
            # quantize input to int8
@@ -207,13 +231,15 @@ class GPTQ:
            q_SNR = torch_snr_error(q_out, self.out1).item()
            fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item()
        else:
            q_SNR = "-"
            fp_SNR = "-"

        table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])
        print(table.draw().split("\n")[-2])

    def fasterquant(
        self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False, name=""
    ):
        self.layer.to(self.dev)

        W = self.layer.weight.data.clone()
@@ -248,7 +274,13 @@ class GPTQ:
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        try:
            H = torch.linalg.cholesky(H, upper=True)
        except Exception:
            # Addition because Falcon fails on h_to_4h
            H = torch.linalg.cholesky(
                H + 1e-5 * torch.eye(H.shape[0]).to(H.device), upper=True
            )
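        # The 1e-5 * I jitter is the usual remedy when round-off leaves the
        # matrix not quite positive definite and the plain Cholesky fails.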
        Hinv = H

        g_idx = []
@@ -272,7 +304,9 @@ class GPTQ:
                if groupsize != -1:
                    if (i1 + i) % groupsize == 0:
                        self.quantizer.find_params(
                            W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
                        )

                    if ((i1 + i) // groupsize) - now_idx == -1:
                        scale.append(self.quantizer.scale)
@@ -281,7 +315,7 @@ class GPTQ:
                q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d**2

                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
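                # GPTQ error feedback: the scaled residual err1 is propagated to
                # the not-yet-quantized columns through row i of the
                # inverse-Hessian Cholesky factor, so later columns compensate
                # for the rounding of this one.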
@@ -306,7 +340,9 @@ class GPTQ:
        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()

        self.print_loss(
            name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)
        )

        if scale == []:
            scale.append(self.quantizer.scale)
@@ -326,15 +362,18 @@ class GPTQ:
def get_wikitext2(nsamples, seed, seqlen, model_id):
    from datasets import load_dataset

    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")

    import random

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
@@ -349,18 +388,21 @@ def get_wikitext2(nsamples, seed, seqlen, model_id):
def get_ptb(nsamples, seed, seqlen, model_id):
    from datasets import load_dataset

    traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
    valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation")

    from transformers import AutoTokenizer

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    except:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt")
    testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt")

    import random

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
@@ -375,22 +417,37 @@ def get_ptb(nsamples, seed, seqlen, model_id):
def get_c4(nsamples, seed, seqlen, model_id):
    from datasets import load_dataset

    traindata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
        split="train",
        use_auth_token=False,
    )
    valdata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
        split="validation",
        use_auth_token=False,
    )

    from transformers import AutoTokenizer

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    except:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

    import random

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
            if trainenc.input_ids.shape[1] >= seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
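        # Rejection-sample documents until one spans at least seqlen tokens,
        # then cut a random seqlen-token window out of it.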
@@ -401,12 +458,13 @@ def get_c4(nsamples, seed, seqlen, model_id):
        trainloader.append((inp, tar))

    import random

    random.seed(0)
    valenc = []
    for _ in range(256):
        while True:
            i = random.randint(0, len(valdata) - 1)
            tmp = tokenizer(valdata[i]["text"], return_tensors="pt")
            if tmp.input_ids.shape[1] >= seqlen:
                break
        i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
@@ -415,7 +473,6 @@ def get_c4(nsamples, seed, seqlen, model_id):
    valenc = torch.hstack(valenc)

    class TokenizerWrapper:
        def __init__(self, input_ids):
            self.input_ids = input_ids
@@ -426,18 +483,21 @@ def get_c4(nsamples, seed, seqlen, model_id):
def get_ptb_new(nsamples, seed, seqlen, model_id):
    from datasets import load_dataset

    traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
    testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")

    from transformers import AutoTokenizer

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    except:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt")
    testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt")

    import random

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
@@ -452,22 +512,35 @@ def get_ptb_new(nsamples, seed, seqlen, model_id):
def get_c4_new(nsamples, seed, seqlen, model_id):
    from datasets import load_dataset

    traindata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
        split="train",
    )
    valdata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
        split="validation",
    )

    from transformers import AutoTokenizer

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    except:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

    import random

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
            if trainenc.input_ids.shape[1] >= seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
@@ -477,11 +550,10 @@ def get_c4_new(nsamples, seed, seqlen, model_id):
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt")
    valenc = valenc.input_ids[:, : (256 * seqlen)]

    class TokenizerWrapper:
        def __init__(self, input_ids):
            self.input_ids = input_ids
@@ -490,35 +562,56 @@ def get_c4_new(nsamples, seed, seqlen, model_id):
    return trainloader, valenc


def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=""):
    if "wikitext2" in name:
        return get_wikitext2(nsamples, seed, seqlen, model_id)
    if "ptb" in name:
        if "new" in name:
            return get_ptb_new(nsamples, seed, seqlen, model_id)
        return get_ptb(nsamples, seed, seqlen, model_id)
    if "c4" in name:
        if "new" in name:
            return get_c4_new(nsamples, seed, seqlen, model_id)
        return get_c4(nsamples, seed, seqlen, model_id)


def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""):
    # Skip last lm_head linear
    # Need isinstance, Falcon is inheriting Linear.
    if isinstance(module, layers) and "lm_head" not in name:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(
            find_layers(
                child, layers=layers, name=name + "." + name1 if name != "" else name1
            )
        )
    return res


@torch.no_grad()
def sequential(
    model,
    dataloader,
    dev,
    nsamples,
    bits,
    groupsize,
    percdamp=0.01,
    sym: bool = False,
    act_order: bool = False,
):
    print("Starting ...")

    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        layers = model.model.layers
        prefix = "model.layers"
    except Exception:
        layers = model.transformer.h
        prefix = "transformer.h"
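    # Llama-style checkpoints keep the decoder stack under model.model.layers;
    # Falcon-style checkpoints expose it as model.transformer.h. The matching
    # prefix is recorded so quantizer keys line up with the module paths.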
    # embeddings = model.get_input_embeddings()
    # embeddings = embeddings.to(dev)
@@ -528,20 +621,22 @@ def sequential(model, dataloader, dev, nsamples, bits, groupsize, percdamp=0.01,
    # layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {"i": 0}
    extra = {}

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            extra.update(kwargs.copy())
            raise ValueError

    layers[0] = Catcher(layers[0])
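    # Catcher records the first decoder layer's inputs plus whatever keyword
    # arguments the architecture passes (attention_mask, position_ids, alibi,
    # ...), then raises so the rest of the forward pass is skipped; the caller
    # swallows the ValueError.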
@@ -558,19 +653,22 @@ def sequential(model, dataloader, dev, nsamples, bits, groupsize, percdamp=0.01,
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)

    extra = {
        k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items()
    }

    print("Ready.")

    quantizers = {}
    for i in range(len(layers)):
        print(f"Quantizing layer {i+1}/{len(layers)}..")
        print("+------------------+--------------+------------+-----------+-------+")
        print("| name             | weight_error | fp_inp_SNR | q_inp_SNR | time  |")
        print("+==================+==============+============+===========+=======+")

        from accelerate.hooks import remove_hook_from_submodules

        layer = layers[i].to(dev)
        remove_hook_from_submodules(layer)
        full = find_layers(layer)
@@ -581,10 +679,11 @@ def sequential(model, dataloader, dev, nsamples, bits, groupsize, percdamp=0.01,
        gptq = {}
        for name in subset:
            gptq[name] = GPTQ(subset[name])
            gptq[name].quantizer.configure(
                bits, perchannel=True, sym=sym, mse=False
            )

        def add_batch(name):
            def tmp(_, inp, out):
                gptq[name].add_batch(inp[0].data, out.data)
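        # add_batch returns the inner closure, registered below as a forward
        # hook, so each calibration batch updates that layer's Hessian stats.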
@@ -595,18 +694,30 @@ def sequential(model, dataloader, dev, nsamples, bits, groupsize, percdamp=0.01,
            handles.append(subset[name].register_forward_hook(add_batch(name)))
        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
        for h in handles:
            h.remove()

        for name in subset:
            scale, zero, g_idx, error = gptq[name].fasterquant(
                percdamp=percdamp,
                groupsize=groupsize,
                actorder=act_order,
                name=name,
            )
            quantizers[f"{prefix}.{i}.{name}"] = (
                gptq[name].quantizer.cpu(),
                scale.cpu(),
                zero.cpu(),
                g_idx.cpu(),
                bits,
                groupsize,
            )
            gptq[name].free()

        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]

        layers[i] = layer.cpu()
        del layer
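        # The second pass re-runs the calibration inputs through the
        # now-quantized layer so the next layer is calibrated on quantized
        # activations; the f"{prefix}.{i}.{name}" keys let pack() match entries
        # to module paths in either architecture.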
@@ -614,8 +725,8 @@ def sequential(model, dataloader, dev, nsamples, bits, groupsize, percdamp=0.01,
        torch.cuda.empty_cache()

        inps, outs = outs, inps
        print("+------------------+--------------+------------+-----------+-------+")
        print("\n")

    # if args.observe:
    #     observer.print()
@@ -746,18 +857,30 @@ def sequential(model, dataloader, dev, nsamples, bits, groupsize, percdamp=0.01,
    #
    # model.config.use_cache = use_cache


def make_quant_linear(module, names, bits, groupsize, name=""):
    if isinstance(module, QuantLinear):
        return
    for attr in dir(module):
        tmp = getattr(module, attr)
        name1 = name + "." + attr if name != "" else attr
        if name1 in names:
            delattr(module, attr)
            setattr(
                module,
                attr,
                QuantLinear.new(
                    bits,
                    groupsize,
                    tmp.in_features,
                    tmp.out_features,
                    tmp.bias is not None,
                ),
            )
    for name1, child in module.named_children():
        make_quant_linear(
            child, names, bits, groupsize, name + "." + name1 if name != "" else name1
        )
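# make_quant_linear walks the module tree and swaps every matched nn.Linear
# for an empty QuantLinear shell of the same shape; pack() fills the shells
# with the packed quantized weights afterwards.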
# TODO: perform packing on GPU
@@ -765,13 +888,13 @@ def pack(model, quantizers, bits, groupsize):
    layers = find_layers(model)
    layers = {n: layers[n] for n in quantizers}
    make_quant_linear(model, quantizers, bits, groupsize)
    qlayers = find_layers(model, (QuantLinear,))
    print("Packing ...")
    for name in qlayers:
        print(name)
        quantizers[name], scale, zero, g_idx, _, _ = quantizers[name]
        qlayers[name].pack(layers[name], scale, zero, g_idx)
    print("Done.")
    return model
@@ -937,9 +1060,16 @@ def pack(model, quantizers, bits, groupsize):
    # print('max memory(MiB):', max_memory)


def quantize(
    model_id: str, bits: int, groupsize: int, output_dir: str, trust_remote_code: bool
):
    print("loading model")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="balanced_low_0",
        trust_remote_code=trust_remote_code,
    )
    print("LOADED model")
    model.seqlen = 2048
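    # device_map="balanced_low_0" spreads the fp16 weights across the visible
    # GPUs while keeping GPU 0 mostly free, which is where the per-layer
    # quantization pass (DEV) runs.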
@@ -947,8 +1077,9 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str, trust_re
    nsamples = 128
    seed = None

    dataloader, testloader = get_loaders(
        dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen
    )

    tick = time.time()
    quantizers = sequential(model, dataloader, DEV, nsamples, bits, groupsize)
@@ -988,7 +1119,6 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str, trust_re
    # generated_ids = model.generate(input_ids, streamer=streamer)
    #
    # if args.quant_directory is not None:
    #     export_quant_table(quantizers, args.quant_directory)
@@ -1000,16 +1130,27 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str, trust_re
    pack(model, quantizers, bits, groupsize)
    from safetensors.torch import save_file
    from transformers.modeling_utils import shard_checkpoint

    state_dict = model.state_dict()
    state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
    state_dict["gptq_bits"] = torch.LongTensor([bits])
    state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])

    max_shard_size = "10GB"
    shards, index = shard_checkpoint(
        state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
    )
    os.makedirs(output_dir, exist_ok=True)
    for shard_file, shard in shards.items():
        save_file(
            shard,
            os.path.join(output_dir, shard_file),
            metadata={
                "format": "pt",
                "quantized": "gptq",
                "origin": "text-generation-inference",
            },
        )
    if index is None:
        path_to_weights = os.path.join(output_dir, "model.safetensors")
        logger.info(f"Model weights saved in {path_to_weights}")
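# A minimal usage sketch (hypothetical values, not part of this commit):
#
#     quantize(
#         model_id="tiiuae/falcon-7b",
#         bits=4,
#         groupsize=128,
#         output_dir="./falcon-7b-gptq",
#         trust_remote_code=True,
#     )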