mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-23 16:02:10 +00:00

Falcon

This commit is contained in:
parent ee1f94e64b
commit e5e552b496
@@ -19,14 +19,23 @@ DEV = torch.device("cuda:0")
 
 
 class Quantizer(nn.Module):
 
     def __init__(self, shape=1):
         super(Quantizer, self).__init__()
-        self.register_buffer('maxq', torch.tensor(0))
-        self.register_buffer('scale', torch.zeros(shape))
-        self.register_buffer('zero', torch.zeros(shape))
+        self.register_buffer("maxq", torch.tensor(0))
+        self.register_buffer("scale", torch.zeros(shape))
+        self.register_buffer("zero", torch.zeros(shape))
 
-    def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False):
-
+    def configure(
+        self,
+        bits,
+        perchannel=False,
+        sym=True,
+        mse=False,
+        norm=2.4,
+        grid=100,
+        maxshrink=0.8,
+        trits=False,
+    ):
         self.maxq = torch.tensor(2**bits - 1)
         self.perchannel = perchannel
@@ -88,14 +97,16 @@ class Quantizer(nn.Module):
             self.zero = torch.round(-xmin / self.scale)
 
         if self.mse:
-            best = torch.full([x.shape[0]], float('inf'), device=dev)
+            best = torch.full([x.shape[0]], float("inf"), device=dev)
             for i in range(int(self.maxshrink * self.grid)):
                 p = 1 - i / self.grid
                 xmin1 = p * xmin
                 xmax1 = p * xmax
                 scale1 = (xmax1 - xmin1) / self.maxq
                 zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
-                q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
+                q = self._quantize(
+                    x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq
+                )
                 q -= x
                 q.abs_()
                 q.pow_(self.norm)
@@ -142,7 +153,6 @@ class Quantizer(nn.Module):
 
 
 class GPTQ:
-
     def __init__(self, layer, observe=False):
         self.layer = layer
         self.dev = self.layer.weight.device
@@ -170,12 +180,19 @@ class GPTQ:
         if len(inp.shape) == 2:
             inp = inp.unsqueeze(0)
         tmp = inp.shape[0]
-        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
+        if isinstance(self.layer, nn.Linear) or isinstance(
+            self.layer, transformers.Conv1D
+        ):
             if len(inp.shape) == 3:
                 inp = inp.reshape((-1, inp.shape[-1]))
             inp = inp.t()
         if isinstance(self.layer, nn.Conv2d):
-            unfold = nn.Unfold(self.layer.kernel_size, dilation=self.layer.dilation, padding=self.layer.padding, stride=self.layer.stride)
+            unfold = nn.Unfold(
+                self.layer.kernel_size,
+                dilation=self.layer.dilation,
+                padding=self.layer.padding,
+                stride=self.layer.stride,
+            )
             inp = unfold(inp)
             inp = inp.permute([1, 0, 2])
             inp = inp.flatten(1)
@@ -188,12 +205,19 @@ class GPTQ:
 
     def print_loss(self, name, q_weight, weight_error, timecost):
         table = Texttable()
-        name += ' ' * (16 - len(name))
+        length = 30
+        name = (
+            (name + " " * (length - len(name)))
+            if len(name) <= length
+            else name[:length]
+        )
 
-        table.header(['name', 'weight_error', 'fp_inp_SNR', 'q_inp_SNR', 'time'])
+        table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"])
 
         # assign weight
-        self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
+        self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(
+            self.layer.weight.data.dtype
+        )
 
         if self.inp1 is not None:
             # quantize input to int8
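The fixed 16-character padding in print_loss is replaced by a pad-or-truncate to a 30-character column, presumably because Falcon's qualified module names no longer fit. The expression reads more easily as a small helper (hypothetical name, same logic as the diff):

def fit_name(name: str, length: int = 30) -> str:
    # Pad short names with spaces, truncate long ones, so the table column stays aligned.
    return (name + " " * (length - len(name))) if len(name) <= length else name[:length]

print(repr(fit_name("mlp.dense_h_to_4h")))                                # padded out to 30 characters
print(repr(fit_name("transformer.h.0.self_attention.query_key_value")))  # cut at 30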
@@ -207,13 +231,15 @@ class GPTQ:
             q_SNR = torch_snr_error(q_out, self.out1).item()
             fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item()
         else:
-            q_SNR = '-'
-            fp_SNR = '-'
+            q_SNR = "-"
+            fp_SNR = "-"
 
         table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])
-        print(table.draw().split('\n')[-2])
+        print(table.draw().split("\n")[-2])
 
-    def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, name=''):
+    def fasterquant(
+        self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False, name=""
+    ):
         self.layer.to(self.dev)
 
         W = self.layer.weight.data.clone()
@@ -248,7 +274,13 @@ class GPTQ:
         H[diag, diag] += damp
         H = torch.linalg.cholesky(H)
         H = torch.cholesky_inverse(H)
-        H = torch.linalg.cholesky(H, upper=True)
+        try:
+            H = torch.linalg.cholesky(H, upper=True)
+        except Exception:
+            # Addition because Falcon fails on h_to_4h
+            H = torch.linalg.cholesky(
+                H + 1e-5 * torch.eye(H.shape[0]).to(H.device), upper=True
+            )
         Hinv = H
 
         g_idx = []
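The hunk above wraps the second Cholesky factorization in a try/except: for Falcon, the inverse Hessian of the h_to_4h projection can come out numerically non-positive-definite, so the factorization is retried after adding a small diagonal term. A standalone sketch of the same fallback pattern (the function name and the parameterized eps are illustrative; 1e-5 is the value used in the diff):

import torch

def cholesky_with_jitter(H: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # First try the plain upper-triangular factorization.
    try:
        return torch.linalg.cholesky(H, upper=True)
    except Exception:
        # If H is not numerically positive definite, add a small diagonal
        # jitter and retry, mirroring what the diff does for Falcon.
        jitter = eps * torch.eye(H.shape[0], device=H.device, dtype=H.dtype)
        return torch.linalg.cholesky(H + jitter, upper=True)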
@@ -272,7 +304,9 @@ class GPTQ:
 
                 if groupsize != -1:
                     if (i1 + i) % groupsize == 0:
-                        self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)
+                        self.quantizer.find_params(
+                            W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
+                        )
 
                     if ((i1 + i) // groupsize) - now_idx == -1:
                         scale.append(self.quantizer.scale)
@@ -281,7 +315,7 @@ class GPTQ:
 
                 q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                 Q1[:, i] = q
-                Losses1[:, i] = (w - q)**2 / d**2
+                Losses1[:, i] = (w - q) ** 2 / d**2
 
                 err1 = (w - q) / d
                 W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
@@ -306,7 +340,9 @@ class GPTQ:
         if isinstance(self.layer, transformers.Conv1D):
             Q = Q.t()
 
-        self.print_loss(name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick))
+        self.print_loss(
+            name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)
+        )
 
         if scale == []:
             scale.append(self.quantizer.scale)
@@ -326,15 +362,18 @@ class GPTQ:
 
 def get_wikitext2(nsamples, seed, seqlen, model_id):
     from datasets import load_dataset
-    traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
-    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
+
+    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
+    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
 
     from transformers import AutoTokenizer
+
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
-    trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
-    testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
+    trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
+    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
 
     import random
+
     random.seed(seed)
     trainloader = []
     for _ in range(nsamples):
@@ -349,18 +388,21 @@ def get_wikitext2(nsamples, seed, seqlen, model_id):
 
 def get_ptb(nsamples, seed, seqlen, model_id):
     from datasets import load_dataset
-    traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
-    valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation')
+
+    traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
+    valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation")
 
     from transformers import AutoTokenizer
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
     except:
         tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
-    trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt')
-    testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt')
+    trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt")
+    testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt")
 
     import random
+
     random.seed(seed)
     trainloader = []
     for _ in range(nsamples):
@@ -375,22 +417,37 @@ def get_ptb(nsamples, seed, seqlen, model_id):
 
 def get_c4(nsamples, seed, seqlen, model_id):
     from datasets import load_dataset
-    traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', use_auth_token=False)
-    valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', use_auth_token=False)
+
+    traindata = load_dataset(
+        "allenai/c4",
+        "allenai--c4",
+        data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
+        split="train",
+        use_auth_token=False,
+    )
+    valdata = load_dataset(
+        "allenai/c4",
+        "allenai--c4",
+        data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
+        split="validation",
+        use_auth_token=False,
+    )
 
     from transformers import AutoTokenizer
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
     except:
         tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
 
     import random
+
     random.seed(seed)
     trainloader = []
     for _ in range(nsamples):
         while True:
             i = random.randint(0, len(traindata) - 1)
-            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
+            trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
             if trainenc.input_ids.shape[1] >= seqlen:
                 break
         i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
@@ -401,12 +458,13 @@ def get_c4(nsamples, seed, seqlen, model_id):
         trainloader.append((inp, tar))
 
     import random
+
     random.seed(0)
     valenc = []
     for _ in range(256):
         while True:
             i = random.randint(0, len(valdata) - 1)
-            tmp = tokenizer(valdata[i]['text'], return_tensors='pt')
+            tmp = tokenizer(valdata[i]["text"], return_tensors="pt")
             if tmp.input_ids.shape[1] >= seqlen:
                 break
         i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
@@ -415,7 +473,6 @@ def get_c4(nsamples, seed, seqlen, model_id):
     valenc = torch.hstack(valenc)
 
     class TokenizerWrapper:
-
         def __init__(self, input_ids):
             self.input_ids = input_ids
 
@@ -426,18 +483,21 @@ def get_c4(nsamples, seed, seqlen, model_id):
 
 def get_ptb_new(nsamples, seed, seqlen, model_id):
     from datasets import load_dataset
-    traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
-    testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
+
+    traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
+    testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
 
     from transformers import AutoTokenizer
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
     except:
         tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
-    trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt')
-    testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt')
+    trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt")
+    testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt")
 
     import random
+
     random.seed(seed)
     trainloader = []
     for _ in range(nsamples):
@@ -452,22 +512,35 @@ def get_ptb_new(nsamples, seed, seqlen, model_id):
 
 def get_c4_new(nsamples, seed, seqlen, model_id):
     from datasets import load_dataset
-    traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
-    valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')
+
+    traindata = load_dataset(
+        "allenai/c4",
+        "allenai--c4",
+        data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
+        split="train",
+    )
+    valdata = load_dataset(
+        "allenai/c4",
+        "allenai--c4",
+        data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
+        split="validation",
+    )
 
     from transformers import AutoTokenizer
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
     except:
         tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
 
     import random
+
     random.seed(seed)
     trainloader = []
     for _ in range(nsamples):
         while True:
             i = random.randint(0, len(traindata) - 1)
-            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
+            trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
             if trainenc.input_ids.shape[1] >= seqlen:
                 break
         i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
@@ -477,11 +550,10 @@ def get_c4_new(nsamples, seed, seqlen, model_id):
         tar[:, :-1] = -100
         trainloader.append((inp, tar))
 
-    valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
-    valenc = valenc.input_ids[:, :(256 * seqlen)]
+    valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt")
+    valenc = valenc.input_ids[:, : (256 * seqlen)]
 
     class TokenizerWrapper:
-
         def __init__(self, input_ids):
             self.input_ids = input_ids
 
@@ -490,35 +562,56 @@ def get_c4_new(nsamples, seed, seqlen, model_id):
     return trainloader, valenc
 
 
-def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=''):
-    if 'wikitext2' in name:
+def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=""):
+    if "wikitext2" in name:
         return get_wikitext2(nsamples, seed, seqlen, model_id)
-    if 'ptb' in name:
-        if 'new' in name:
+    if "ptb" in name:
+        if "new" in name:
             return get_ptb_new(nsamples, seed, seqlen, model_id)
         return get_ptb(nsamples, seed, seqlen, model_id)
-    if 'c4' in name:
-        if 'new' in name:
+    if "c4" in name:
+        if "new" in name:
             return get_c4_new(nsamples, seed, seqlen, model_id)
         return get_c4(nsamples, seed, seqlen, model_id)
 
 
-def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
+def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""):
     # Skip last lm_head linear
-    if type(module) in layers and "lm_head" not in name:
+    # Need isintance Falcon is inheriting Linear.
+    if isinstance(module, layers) and "lm_head" not in name:
         return {name: module}
     res = {}
     for name1, child in module.named_children():
-        res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
+        res.update(
+            find_layers(
+                child, layers=layers, name=name + "." + name1 if name != "" else name1
+            )
+        )
     return res
 
 
 @torch.no_grad()
-def sequential(model, dataloader, dev, nsamples, bits, groupsize, percdamp=0.01, sym: bool=False, act_order: bool = False):
-    print('Starting ...')
+def sequential(
+    model,
+    dataloader,
+    dev,
+    nsamples,
+    bits,
+    groupsize,
+    percdamp=0.01,
+    sym: bool = False,
+    act_order: bool = False,
+):
+    print("Starting ...")
 
     use_cache = model.config.use_cache
     model.config.use_cache = False
-    layers = model.model.layers
+    try:
+        layers = model.model.layers
+        prefix = "model.layers"
+    except Exception:
+        layers = model.transformer.h
+        prefix = "transformer.h"
+
     # embeddings = model.get_input_embeddings()
     # embeddings = embeddings.to(dev)
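Two details in this hunk are what actually unblock Falcon. First, find_layers now tests isinstance(module, layers) instead of type(module) in layers: Falcon's remote-code model defines its own Linear subclass, which an exact type check skips but isinstance still matches (isinstance also wants a class or tuple of classes, hence the default changing to (nn.Conv2d, nn.Linear)). Second, sequential probes model.model.layers and falls back to model.transformer.h, remembering the matching prefix for later. A small illustration of the first point, with a hypothetical stand-in subclass:

import torch.nn as nn

class FalconLinear(nn.Linear):  # hypothetical stand-in for Falcon's custom Linear
    pass

layer = FalconLinear(4, 4)
layer_types = (nn.Conv2d, nn.Linear)

print(type(layer) in layer_types)      # False: an exact type match misses the subclass
print(isinstance(layer, layer_types))  # True: isinstance also accepts subclasses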
@@ -528,20 +621,22 @@ def sequential(model, dataloader, dev, nsamples, bits, groupsize, percdamp=0.01, sym: bool=False, act_order: bool = False):
     # layers[0] = layers[0].to(dev)
 
     dtype = next(iter(model.parameters())).dtype
-    inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
-    cache = {'i': 0, 'attention_mask': None}
+    inps = torch.zeros(
+        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
+    )
+    cache = {"i": 0}
+    extra = {}
 
     class Catcher(nn.Module):
 
         def __init__(self, module):
             super().__init__()
             self.module = module
 
         def forward(self, inp, **kwargs):
-            inps[cache['i']] = inp
-            cache['i'] += 1
-            cache['attention_mask'] = kwargs['attention_mask']
-            cache['position_ids'] = kwargs['position_ids']
+            inps[cache["i"]] = inp
+            cache["i"] += 1
+            extra.update(kwargs.copy())
             raise ValueError
 
     layers[0] = Catcher(layers[0])
@@ -558,19 +653,22 @@ def sequential(model, dataloader, dev, nsamples, bits, groupsize, percdamp=0.01, sym: bool=False, act_order: bool = False):
     torch.cuda.empty_cache()
 
     outs = torch.zeros_like(inps)
-    attention_mask = cache['attention_mask'].to(dev)
-    position_ids = cache['position_ids'].to(dev)
 
-    print('Ready.')
+    extra = {
+        k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items()
+    }
+
+    print("Ready.")
 
     quantizers = {}
     for i in range(len(layers)):
-        print(f'Quantizing layer {i+1}/{len(layers)}..')
-        print('+------------------+--------------+------------+-----------+-------+')
-        print('| name | weight_error | fp_inp_SNR | q_inp_SNR | time |')
-        print('+==================+==============+============+===========+=======+')
+        print(f"Quantizing layer {i+1}/{len(layers)}..")
+        print("+------------------+--------------+------------+-----------+-------+")
+        print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
+        print("+==================+==============+============+===========+=======+")
 
         from accelerate.hooks import remove_hook_from_submodules
+
         layer = layers[i].to(dev)
         remove_hook_from_submodules(layer)
         full = find_layers(layer)
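The two hunks above replace the hard-coded attention_mask/position_ids plumbing with a generic extra dict: the Catcher records whatever keyword arguments the first block receives (Falcon's blocks are not called with the same kwargs as LLaMA's), any tensors among them are later moved to the target device, and they are replayed with **extra. A condensed sketch of the pattern, using the same names as the diff but lifted out of its surrounding loop:

import torch
import torch.nn as nn

extra = {}  # filled with whatever kwargs the first decoder block receives

class Catcher(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, inp, **kwargs):
        # Keep a copy of every keyword argument, whatever the architecture passes
        # (attention_mask, alibi, position_ids, ...), then abort the forward pass.
        extra.update(kwargs.copy())
        raise ValueError

def move_extra_to(dev: torch.device) -> dict:
    # Only tensors need to move; other values (bools, None, ...) pass through.
    return {
        k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items()
    }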
@@ -581,10 +679,11 @@ def sequential(model, dataloader, dev, nsamples, bits, groupsize, percdamp=0.01, sym: bool=False, act_order: bool = False):
         gptq = {}
         for name in subset:
             gptq[name] = GPTQ(subset[name])
-            gptq[name].quantizer.configure(bits, perchannel=True, sym=sym, mse=False)
+            gptq[name].quantizer.configure(
+                bits, perchannel=True, sym=sym, mse=False
+            )
 
         def add_batch(name):
-
             def tmp(_, inp, out):
                 gptq[name].add_batch(inp[0].data, out.data)
 
@@ -595,18 +694,30 @@ def sequential(model, dataloader, dev, nsamples, bits, groupsize, percdamp=0.01, sym: bool=False, act_order: bool = False):
             handles.append(subset[name].register_forward_hook(add_batch(name)))
         for j in range(nsamples):
 
-            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
+            outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
         for h in handles:
             h.remove()
 
         for name in subset:
-            scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=percdamp, groupsize=groupsize, actorder=act_order, name=name)
-            quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), bits, groupsize)
+            scale, zero, g_idx, error = gptq[name].fasterquant(
+                percdamp=percdamp,
+                groupsize=groupsize,
+                actorder=act_order,
+                name=name,
+            )
+            quantizers[f"{prefix}.{i}.{name}"] = (
+                gptq[name].quantizer.cpu(),
+                scale.cpu(),
+                zero.cpu(),
+                g_idx.cpu(),
+                bits,
+                groupsize,
+            )
 
             gptq[name].free()
 
         for j in range(nsamples):
-            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
+            outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
 
         layers[i] = layer.cpu()
         del layer
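Note that the quantizers keys are now built from the prefix detected at the top of sequential, so the entries line up with real module names for both LLaMA-style (model.layers.*) and Falcon-style (transformer.h.*) checkpoints instead of always assuming model.layers. For illustration only (the submodule name below is a hypothetical example):

prefix = "transformer.h"  # what the try/except above would pick for a Falcon-style model
i, name = 0, "self_attention.query_key_value"
print(f"{prefix}.{i}.{name}")  # transformer.h.0.self_attention.query_key_value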
@@ -614,8 +725,8 @@ def sequential(model, dataloader, dev, nsamples, bits, groupsize, percdamp=0.01, sym: bool=False, act_order: bool = False):
         torch.cuda.empty_cache()
 
         inps, outs = outs, inps
-        print('+------------------+--------------+------------+-----------+-------+')
-        print('\n')
+        print("+------------------+--------------+------------+-----------+-------+")
+        print("\n")
 
     # if args.observe:
     #     observer.print()
@@ -746,18 +857,30 @@ def sequential(model, dataloader, dev, nsamples, bits, groupsize, percdamp=0.01, sym: bool=False, act_order: bool = False):
     #
     # model.config.use_cache = use_cache
 
-def make_quant_linear(module, names, bits, groupsize, name=''):
+
+def make_quant_linear(module, names, bits, groupsize, name=""):
     if isinstance(module, QuantLinear):
         return
     for attr in dir(module):
         tmp = getattr(module, attr)
-        name1 = name + '.' + attr if name != '' else attr
+        name1 = name + "." + attr if name != "" else attr
         if name1 in names:
             delattr(module, attr)
-            setattr(module, attr, QuantLinear.new(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None))
+            setattr(
+                module,
+                attr,
+                QuantLinear.new(
+                    bits,
+                    groupsize,
+                    tmp.in_features,
+                    tmp.out_features,
+                    tmp.bias is not None,
+                ),
+            )
     for name1, child in module.named_children():
-        make_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
+        make_quant_linear(
+            child, names, bits, groupsize, name + "." + name1 if name != "" else name1
+        )
 
 
 # TODO: perform packing on GPU
@@ -765,13 +888,13 @@ def pack(model, quantizers, bits, groupsize):
     layers = find_layers(model)
     layers = {n: layers[n] for n in quantizers}
     make_quant_linear(model, quantizers, bits, groupsize)
-    qlayers = find_layers(model, [QuantLinear])
-    print('Packing ...')
+    qlayers = find_layers(model, (QuantLinear,))
+    print("Packing ...")
     for name in qlayers:
         print(name)
         quantizers[name], scale, zero, g_idx, _, _ = quantizers[name]
         qlayers[name].pack(layers[name], scale, zero, g_idx)
-    print('Done.')
+    print("Done.")
     return model
 
 
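pack now passes a tuple rather than a list to find_layers, which matters after the isinstance change above: isinstance accepts a class or a tuple of classes but raises TypeError on a list. A one-line check of that behaviour:

import torch.nn as nn

layer = nn.Linear(2, 2)
print(isinstance(layer, (nn.Linear,)))  # True
# isinstance(layer, [nn.Linear])        # would raise TypeError: a type or tuple of types is required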
@@ -937,9 +1060,16 @@ def pack(model, quantizers, bits, groupsize):
     # print('max memory(MiB):', max_memory)
 
 
-def quantize(model_id: str, bits: int, groupsize: int, output_dir: str, trust_remote_code: bool):
+def quantize(
+    model_id: str, bits: int, groupsize: int, output_dir: str, trust_remote_code: bool
+):
     print("loading model")
-    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="balanced_low_0", trust_remote_code=trust_remote_code)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        torch_dtype=torch.float16,
+        device_map="balanced_low_0",
+        trust_remote_code=trust_remote_code,
+    )
     print("LOADED model")
     model.seqlen = 2048
 
@@ -947,8 +1077,9 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str, trust_remote_code: bool):
     nsamples = 128
     seed = None
 
-    dataloader, testloader = get_loaders(dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen)
+    dataloader, testloader = get_loaders(
+        dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen
+    )
 
     tick = time.time()
     quantizers = sequential(model, dataloader, DEV, nsamples, bits, groupsize)
@@ -988,7 +1119,6 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str, trust_remote_code: bool):
     # generated_ids = model.generate(input_ids, streamer=streamer)
     #
 
-
     # if args.quant_directory is not None:
     #     export_quant_table(quantizers, args.quant_directory)
 
@@ -1000,16 +1130,27 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str, trust_remote_code: bool):
     pack(model, quantizers, bits, groupsize)
     from safetensors.torch import save_file
     from transformers.modeling_utils import shard_checkpoint
+
     state_dict = model.state_dict()
     state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
     state_dict["gptq_bits"] = torch.LongTensor([bits])
     state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
 
     max_shard_size = "10GB"
-    shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors")
+    shards, index = shard_checkpoint(
+        state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
+    )
     os.makedirs(output_dir, exist_ok=True)
     for shard_file, shard in shards.items():
-        save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt", "quantized": "gptq", "origin": "text-generation-inference"})
+        save_file(
+            shard,
+            os.path.join(output_dir, shard_file),
+            metadata={
+                "format": "pt",
+                "quantized": "gptq",
+                "origin": "text-generation-inference",
+            },
+        )
     if index is None:
         path_to_weights = os.path.join(output_dir, "model.safetensors")
         logger.info(f"Model weights saved in {path_to_weights}")
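For reference, the metadata written into each shard can be inspected later without loading any tensors; a minimal sketch with the safetensors API, assuming the single-shard file name produced above:

from safetensors import safe_open

with safe_open("model.safetensors", framework="pt") as f:
    print(f.metadata())
    # e.g. {'format': 'pt', 'quantized': 'gptq', 'origin': 'text-generation-inference'}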