mirror of https://github.com/huggingface/text-generation-inference.git

Adding quantization scripts.
parent da8ebf16fe, commit 5a72715344
@@ -150,6 +150,32 @@ def download_weights(
     # Convert pytorch weights to safetensors
     utils.convert_files(local_pt_files, local_st_files)
 
 
+@app.command()
+def quantize(
+    model_id: str,
+    revision: Optional[str] = None,
+    logger_level: str = "INFO",
+    json_output: bool = False,
+    extension: str = ".safetensors",
+):
+    # Remove default handler
+    logger.remove()
+    logger.add(
+        sys.stdout,
+        format="{message}",
+        filter="text_generation_server",
+        level=logger_level,
+        serialize=json_output,
+        backtrace=True,
+        diagnose=False,
+    )
+    download_weights(model_id=model_id, revision=revision, logger_level=logger_level, json_output=json_output)
+    from text_generation_server.utils.gptq.quantize import quantize
+
+    quantize(model_id=model_id, wbits=4, groupsize=128)
+
+
 if __name__ == "__main__":
     app()
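The new command is equivalent to calling the two functions it wraps directly. A minimal sketch, not part of the commit (the model id is only an illustration, and the import paths assume the server package is installed):

from text_generation_server.cli import download_weights
from text_generation_server.utils.gptq.quantize import quantize

# Fetch (and convert) the weights, then run 4-bit GPTQ with groupsize 128,
# which is what `text-generation-server quantize <model_id>` does above.
download_weights(model_id="bigscience/bloom-560m")
quantize(model_id="bigscience/bloom-560m", wbits=4, groupsize=128)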
@@ -247,6 +247,8 @@ def get_model(
 
     if sharded:
         raise ValueError("sharded is not supported for AutoModel")
+    if quantize == "gptq":
+        raise ValueError("gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
 
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
@@ -42,7 +42,8 @@ from text_generation_server.utils.layers import (
 
 
 def load_row(config, prefix: str, weights, bias: bool):
-    weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+    weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
 
     if bias and weights.process_group.rank() == 0:
         # Rank is only on the first rank process
         bias = weights.get_tensor(f"{prefix}.bias")
@@ -57,9 +58,9 @@ def load_row(config, prefix: str, weights, bias: bool):
 
 
 def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
-    weight = weights.get_sharded(f"{prefix}.weight", dim=0)
-    bias = weights.get_sharded(f"{prefix}.bias", dim=0)
+    weight = weights.get_multi_weights_col([prefix], quantize=config.quantize)
+    if isinstance(weight, torch.Tensor):
+        # Only on non quantized versions
         weight = (
             weight.view(
                 num_heads,
@@ -70,8 +71,11 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
             .permute(1, 0, 2, 3)
             .reshape(-1, hidden_size)
         )
 
+    bias = weights.get_sharded(f"{prefix}.bias", dim=0)
     bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
 
     linear = get_linear(weight, bias, config.quantize)
     if config.use_parallel_residual:
         return linear
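The isinstance(weight, torch.Tensor) check above works because the new weight accessors return a plain tensor for unquantized checkpoints but a GPTQ tuple (qweight, qzeros, scales, g_idx, bits, groupsize) otherwise (see the Weights hunk at the end of this diff). A minimal sketch of that contract, with a hypothetical stand-in for the accessor and purely illustrative shapes:

import torch

# Hypothetical stand-in for weights.get_multi_weights_col(...).
def get_weight(quantize: str):
    if quantize == "gptq":
        qweight = torch.zeros((128, 256), dtype=torch.int32)   # packed int4 weights
        qzeros = torch.zeros((8, 32), dtype=torch.int32)       # packed zero points
        scales = torch.ones((8, 256))                           # per-group scales
        g_idx = torch.arange(1024, dtype=torch.int32) // 128    # column -> group map
        return qweight, qzeros, scales, g_idx, 4, 128
    return torch.randn(1024, 1024)

weight = get_weight("gptq")
if isinstance(weight, torch.Tensor):
    print("dense weight", weight.shape)  # reshape/permute as in load_qkv
else:
    qweight, qzeros, scales, g_idx, bits, groupsize = weight
    print("gptq weight", bits, groupsize)  # handed to get_linear as-is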
server/text_generation_server/utils/gptq/quantize.py (new file, 989 lines)
@@ -0,0 +1,989 @@
import argparse
import time
import math

import numpy as np
import torch
import torch.nn as nn

from texttable import Texttable
from transformers import AutoModelForCausalLM
import transformers

DEV = torch.device("cuda:0")

class Quantizer(nn.Module):

    def __init__(self, shape=1):
        super(Quantizer, self).__init__()
        self.register_buffer('maxq', torch.tensor(0))
        self.register_buffer('scale', torch.zeros(shape))
        self.register_buffer('zero', torch.zeros(shape))

    def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False):
        self.maxq = torch.tensor(2**bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        if trits:
            self.maxq = torch.tensor(-1)
        self.scale = torch.zeros_like(self.scale)

    def _quantize(self, x, scale, zero, maxq):
        if maxq < 0:
            return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
        q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
        return scale * (q - zero)

    def find_params(self, x, weight=False):
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        if self.perchannel:
            if weight:
                x = x.flatten(1)
            else:
                if len(shape) == 4:
                    x = x.permute([1, 0, 2, 3])
                    x = x.flatten(1)
                if len(shape) == 3:
                    x = x.reshape((-1, shape[-1])).t()
                if len(shape) == 2:
                    x = x.t()
        else:
            x = x.flatten().unsqueeze(0)

        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmin < 0
            if torch.any(tmp):
                xmin[tmp] = -xmax[tmp]
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        if self.maxq < 0:
            self.scale = xmax
            self.zero = xmin
        else:
            self.scale = (xmax - xmin) / self.maxq
            if self.sym:
                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
            else:
                self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            best = torch.full([x.shape[0]], float('inf'), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                tmp = err < best
                if torch.any(tmp):
                    best[tmp] = err[tmp]
                    self.scale[tmp] = scale1[tmp]
                    self.zero[tmp] = zero1[tmp]
        if not self.perchannel:
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        if weight:
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        if len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        if len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        if self.ready():
            return self._quantize(x, self.scale, self.zero, self.maxq)

        return x

    def enabled(self):
        return self.maxq > 0

    def ready(self):
        return torch.all(self.scale != 0)

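To make the scale/zero bookkeeping above concrete, here is a small round-trip sketch (not part of the file; synthetic data, 4 bits, per-channel asymmetric) showing that find_params followed by quantize returns a dequantized tensor close to the input:

import torch

# Assumes the Quantizer class defined above is in scope.
q = Quantizer()
q.configure(bits=4, perchannel=True, sym=False, mse=False)

W = torch.randn(8, 64)          # (out_features, in_features) weight matrix
q.find_params(W, weight=True)   # per-output-channel scale and zero point

W_dq = q.quantize(W)            # quantize, then immediately dequantize
err = (W - W_dq).abs().max().item()
print(f"max abs round-trip error: {err:.4f}")  # bounded by ~scale/2 per channel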
class GPTQ:

    def __init__(self, layer, observe=False):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        self.quantizer = Quantizer()
        self.observe = observe

    def add_batch(self, inp, out):
        # Hessian H = 2 X X^T + λ I
        if self.observe:
            self.inp1 = inp
            self.out1 = out
        else:
            self.inp1 = None
            self.out1 = None

        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(self.layer.kernel_size, dilation=self.layer.dilation, padding=self.layer.padding, stride=self.layer.stride)
            inp = unfold(inp)
            inp = inp.permute([1, 0, 2])
            inp = inp.flatten(1)
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        # inp = inp.float()
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        # self.H += 2 / self.nsamples * inp.matmul(inp.t())
        self.H += inp.matmul(inp.t())

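The rescale-then-accumulate pattern in add_batch is a streaming mean: after n samples, H equals (2/n) Σᵢ xᵢxᵢᵀ without ever holding all calibration activations at once. A small sketch (illustrative names, not part of the file) verifying the equivalence on random data:

import math
import torch

cols, batches = 16, 5
xs = [torch.randn(cols, 32) for _ in range(batches)]  # columns x tokens

# Streaming update, as in GPTQ.add_batch.
H, n = torch.zeros(cols, cols), 0
for x in xs:
    tmp = x.shape[1]
    H *= n / (n + tmp)
    n += tmp
    x = math.sqrt(2 / n) * x.float()
    H += x @ x.t()

# Direct computation of (2/n) * sum_i x_i x_i^T over all samples at once.
X = torch.cat(xs, dim=1)
H_direct = 2 / X.shape[1] * X @ X.t()
print(torch.allclose(H, H_direct, atol=1e-4))  # True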
    def print_loss(self, name, q_weight, weight_error, timecost):
        table = Texttable()
        name += ' ' * (16 - len(name))

        table.header(['name', 'weight_error', 'fp_inp_SNR', 'q_inp_SNR', 'time'])

        # assign weight
        self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)

        if self.inp1 is not None:
            # quantize input to int8
            quantizer = Quantizer()
            quantizer.configure(8, perchannel=False, sym=True, mse=False)
            quantizer.find_params(self.inp1)
            q_in = quantizer.quantize(self.inp1).type(torch.float16)
            q_out = self.layer(q_in)

            # get kinds of SNR
            # NOTE: torch_snr_error (a signal-to-noise metric between two tensors)
            # is only reached in observe mode and is not defined or imported in
            # this file as committed.
            q_SNR = torch_snr_error(q_out, self.out1).item()
            fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item()
        else:
            q_SNR = '-'
            fp_SNR = '-'

        table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])
        print(table.draw().split('\n')[-2])

    def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, name=''):
        self.layer.to(self.dev)

        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

        H = self.H
        if not self.observe:
            del self.H
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        if actorder:
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]

        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        g_idx = []
        scale = []
        zero = []
        now_idx = 1

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if groupsize != -1:
                    if (i1 + i) % groupsize == 0:
                        self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)

                    if ((i1 + i) // groupsize) - now_idx == -1:
                        scale.append(self.quantizer.scale)
                        zero.append(self.quantizer.zero)
                        now_idx += 1

                q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q)**2 / d**2

                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            Losses[:, i1:i2] = Losses1 / 2

            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        torch.cuda.synchronize()
        error = torch.sum(Losses).item()

        groupsize = groupsize if groupsize != -1 else self.columns
        g_idx = [i // groupsize for i in range(self.columns)]
        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
        if actorder:
            invperm = torch.argsort(perm)
            Q = Q[:, invperm]
            g_idx = g_idx[invperm]

        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()

        self.print_loss(name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick))

        if scale == []:
            scale.append(self.quantizer.scale)
            zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)
        return scale, zero, g_idx, error

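The three Cholesky calls in fasterquant are the standard GPTQ trick: rather than inverting the damped Hessian directly, they produce the upper-triangular Cholesky factor of H⁻¹, whose diagonal entries d and rows drive the per-column error feedback. A quick sketch (not part of the file) checking the identity on a random positive-definite matrix:

import torch

n = 64
A = torch.randn(n, n)
H = A @ A.t() + n * torch.eye(n)  # symmetric positive definite, like the damped Hessian

L = torch.linalg.cholesky(H)                  # H = L L^T
Hinv = torch.cholesky_inverse(L)              # H^{-1} from its Cholesky factor
U = torch.linalg.cholesky(Hinv, upper=True)   # upper factor: H^{-1} = U^T U

print(torch.allclose(U.t() @ U, torch.inverse(H), atol=1e-4))  # True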
    def free(self):
        self.inp1 = None
        self.out1 = None
        self.H = None
        self.Losses = None
        self.Trace = None
        torch.cuda.empty_cache()

def get_wikitext2(nsamples, seed, seqlen, model_id):
    from datasets import load_dataset
    traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
    testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')

    import random
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc


def get_ptb(nsamples, seed, seqlen, model_id):
    from datasets import load_dataset
    traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
    valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation')

    from transformers import AutoTokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt')
    testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt')

    import random
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc


def get_c4(nsamples, seed, seqlen, model_id):
    from datasets import load_dataset
    traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', use_auth_token=False)
    valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', use_auth_token=False)

    from transformers import AutoTokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

    import random
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] >= seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    random.seed(0)  # fixed seed for the validation sample
    valenc = []
    for _ in range(256):
        while True:
            i = random.randint(0, len(valdata) - 1)
            tmp = tokenizer(valdata[i]['text'], return_tensors='pt')
            if tmp.input_ids.shape[1] >= seqlen:
                break
        i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        valenc.append(tmp.input_ids[:, i:j])
    valenc = torch.hstack(valenc)

    class TokenizerWrapper:

        def __init__(self, input_ids):
            self.input_ids = input_ids

    valenc = TokenizerWrapper(valenc)

    return trainloader, valenc


def get_ptb_new(nsamples, seed, seqlen, model_id):
    from datasets import load_dataset
    traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
    testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')

    from transformers import AutoTokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt')
    testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt')

    import random
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc


def get_c4_new(nsamples, seed, seqlen, model_id):
    from datasets import load_dataset
    traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
    valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')

    from transformers import AutoTokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

    import random
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] >= seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
    valenc = valenc.input_ids[:, :(256 * seqlen)]

    class TokenizerWrapper:

        def __init__(self, input_ids):
            self.input_ids = input_ids

    valenc = TokenizerWrapper(valenc)

    return trainloader, valenc


def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=''):
    if 'wikitext2' in name:
        return get_wikitext2(nsamples, seed, seqlen, model_id)
    if 'ptb' in name:
        if 'new' in name:
            return get_ptb_new(nsamples, seed, seqlen, model_id)
        return get_ptb(nsamples, seed, seqlen, model_id)
    if 'c4' in name:
        if 'new' in name:
            return get_c4_new(nsamples, seed, seqlen, model_id)
        return get_c4(nsamples, seed, seqlen, model_id)

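Each loader returns (trainloader, testenc): nsamples random windows of seqlen tokens for calibration, with labels masked to -100 everywhere except the final position, plus a held-out encoding for perplexity checks. A hedged usage sketch (the model id is illustrative; this downloads wikitext-2 and the model's tokenizer):

dataloader, testenc = get_loaders(
    "wikitext2", nsamples=4, seed=0, seqlen=2048, model_id="bigscience/bloom-560m"
)
inp, tar = dataloader[0]
print(inp.shape)                      # torch.Size([1, 2048])
print((tar == -100).float().mean())   # all but the final token is masked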
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
    # Skip last lm_head linear
    if type(module) in layers and "lm_head" not in name:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
    return res

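find_layers recursively collects every quantizable nn.Linear/nn.Conv2d under a module, keyed by dotted path, skipping lm_head. For example, on a toy module (illustrative only):

import torch.nn as nn

toy = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2)))
print(find_layers(toy))
# {'0': Linear(in_features=4, out_features=4, ...),
#  '1.0': Linear(in_features=4, out_features=2, ...)}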
@torch.no_grad()
def sequential(model, dataloader, dev, nsamples, wbits, groupsize, percdamp=0.01, sym: bool = False, act_order: bool = False):
    print('Starting ...')

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    # embeddings = model.get_input_embeddings()
    # embeddings = embeddings.to(dev)
    # model.set_input_embeddings(embeddings)
    # model.model.embed_tokens = model.model.embed_tokens.to(dev)
    # model.model.norm = model.model.norm.to(dev)
    # layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(nn.Module):

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            cache['position_ids'] = kwargs['position_ids']
            raise ValueError

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0])
        except ValueError:
            pass
    layers[0] = layers[0].module

    # layers[0] = layers[0].cpu()
    # model.model.embed_tokens = model.model.embed_tokens.cpu()
    # model.model.norm = model.model.norm.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask'].to(dev)
    position_ids = cache['position_ids'].to(dev)

    print('Ready.')

    quantizers = {}
    for i in range(len(layers)):

        print(f'Quantizing layer {i+1}/{len(layers)}..')
        print('+------------------+--------------+------------+-----------+-------+')
        print('| name             | weight_error | fp_inp_SNR | q_inp_SNR | time  |')
        print('+==================+==============+============+===========+=======+')

        from accelerate.hooks import remove_hook_from_submodules
        layer = layers[i].to(dev)
        remove_hook_from_submodules(layer)
        full = find_layers(layer)
        sequential = [list(full.keys())]

        for names in sequential:
            subset = {n: full[n] for n in names}
            gptq = {}
            for name in subset:
                gptq[name] = GPTQ(subset[name])
                gptq[name].quantizer.configure(wbits, perchannel=True, sym=sym, mse=False)

            def add_batch(name):

                def tmp(_, inp, out):
                    gptq[name].add_batch(inp[0].data, out.data)

                return tmp

            handles = []
            for name in subset:
                handles.append(subset[name].register_forward_hook(add_batch(name)))
            for j in range(nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
            for h in handles:
                h.remove()

            for name in subset:
                scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=percdamp, groupsize=groupsize, actorder=act_order, name=name)
                quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), wbits, groupsize)

                gptq[name].free()

        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]

        layers[i] = layer.cpu()
        del layer
        del gptq
        torch.cuda.empty_cache()

        inps, outs = outs, inps
        print('+------------------+--------------+------------+-----------+-------+')
        print('\n')

    # if args.observe:
    #     observer.print()
    #     conditions = gen_conditions(args.wbits, args.groupsize)
    #     for item in observer.items():
    #         name = item[0]
    #         layerid = item[1]
    #         gptq = item[2]['gptq']
    #         error = item[2]['error']
    #         target = error / 2

    #         table = Texttable()
    #         table.header(['wbits', 'groupsize', 'error'])
    #         table.set_cols_dtype(['i', 'i', 'f'])
    #         table.add_row([args.wbits, args.groupsize, error])

    #         print('Optimizing {} {} ..'.format(name, layerid))
    #         for wbits, groupsize in conditions:

    #             if error < target:
    #                 # if error dropped 50%, skip
    #                 break

    #             gptq.quantizer.configure(wbits, perchannel=True, sym=args.sym, mse=False)

    #             scale, zero, g_idx, error = gptq.fasterquant(percdamp=args.percdamp, groupsize=groupsize, actorder=args.act_order, name=name)

    #             table.add_row([wbits, groupsize, error])
    #             quantizers['model.layers.%d.%s' % (layerid, name)] = (gptq.quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), wbits, groupsize)

    #         print(table.draw())
    #         print('\n')
    #         gptq.layer.to('cpu')
    #         gptq.free()

    model.config.use_cache = use_cache

    return quantizers

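The Catcher trick above runs the model only up to the first decoder layer: the wrapper records the layer's inputs (plus attention mask and position ids), then raises to abort the rest of the forward pass, so calibration activations are captured one layer at a time. A stripped-down sketch of the same pattern on a toy model (not part of the file):

import torch
import torch.nn as nn

toy = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8))
captured = []

class Catcher(nn.Module):
    # Records the input of the wrapped layer, then aborts the forward pass.
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, inp):
        captured.append(inp)
        raise ValueError  # nothing past the first layer ever runs

toy[0] = Catcher(toy[0])
try:
    toy(torch.randn(2, 8))
except ValueError:
    pass
toy[0] = toy[0].module  # unwrap

print(len(captured), captured[0].shape)  # 1 torch.Size([2, 8])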
# @torch.no_grad()
# def llama_eval(model, testenc, dev):
#     print('Evaluating ...')
#
#     testenc = testenc.input_ids
#     nsamples = testenc.numel() // model.seqlen
#
#     use_cache = model.config.use_cache
#     model.config.use_cache = False
#     layers = model.model.layers
#
#     model.model.embed_tokens = model.model.embed_tokens.to(dev)
#     layers[0] = layers[0].to(dev)
#
#     dtype = next(iter(model.parameters())).dtype
#     inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
#     cache = {'i': 0, 'attention_mask': None}
#
#     class Catcher(nn.Module):
#
#         def __init__(self, module):
#             super().__init__()
#             self.module = module
#
#         def forward(self, inp, **kwargs):
#             inps[cache['i']] = inp
#             cache['i'] += 1
#             cache['attention_mask'] = kwargs['attention_mask']
#             cache['position_ids'] = kwargs['position_ids']
#             raise ValueError
#
#     layers[0] = Catcher(layers[0])
#     for i in range(nsamples):
#         batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
#         try:
#             model(batch)
#         except ValueError:
#             pass
#     layers[0] = layers[0].module
#
#     layers[0] = layers[0].cpu()
#     model.model.embed_tokens = model.model.embed_tokens.cpu()
#     torch.cuda.empty_cache()
#
#     outs = torch.zeros_like(inps)
#     attention_mask = cache['attention_mask']
#     position_ids = cache['position_ids']
#
#     for i in range(len(layers)):
#         print(i)
#         layer = layers[i].to(dev)
#
#         if args.nearest:
#             subset = find_layers(layer)
#             for name in subset:
#                 quantizer = quant.Quantizer()
#                 quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)
#                 W = subset[name].weight.data
#                 quantizer.find_params(W, weight=True)
#                 subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype)
#
#         for j in range(nsamples):
#             outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
#         layers[i] = layer.cpu()
#         del layer
#         torch.cuda.empty_cache()
#         inps, outs = outs, inps
#
#     if model.model.norm is not None:
#         model.model.norm = model.model.norm.to(dev)
#     model.lm_head = model.lm_head.to(dev)
#
#     testenc = testenc.to(dev)
#     nlls = []
#     for i in range(nsamples):
#         hidden_states = inps[i].unsqueeze(0)
#         if model.model.norm is not None:
#             hidden_states = model.model.norm(hidden_states)
#         lm_logits = model.lm_head(hidden_states)
#         shift_logits = lm_logits[:, :-1, :].contiguous()
#         shift_labels = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)][:, 1:]
#         loss_fct = nn.CrossEntropyLoss()
#         loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
#         neg_log_likelihood = loss.float() * model.seqlen
#         nlls.append(neg_log_likelihood)
#     ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
#     print(ppl.item())
#
#     model.config.use_cache = use_cache

# TODO: perform packing on GPU
def pack(model, quantizers, wbits, groupsize):
    # NOTE: `quant` (make_quant_linear) and `QuantLinear` are not imported in
    # this file as committed; they are expected to come from the GPTQ CUDA
    # kernel utilities.
    layers = find_layers(model)
    layers = {n: layers[n] for n in quantizers}
    quant.make_quant_linear(model, quantizers, wbits, groupsize)
    qlayers = find_layers(model, [QuantLinear])
    print('Packing ...')
    for name in qlayers:
        print(name)
        quantizers[name], scale, zero, g_idx, _, _ = quantizers[name]
        qlayers[name].pack(layers[name], scale, zero, g_idx)
    print('Done.')
    return model

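For intuition about what pack produces: a 4-bit QuantLinear stores eight quantized values per int32 in its qweight buffer. A minimal sketch of that packing scheme, not the QuantLinear implementation itself, just the idea (LSB-first packing, as GPTQ-style layers use):

import torch

def pack_int4(vals: torch.Tensor) -> int:
    # Pack eight 4-bit integers (values 0..15) into one int32, LSB-first.
    assert vals.numel() == 8 and int(vals.max()) < 16
    word = 0
    for k, v in enumerate(vals.tolist()):
        word |= v << (4 * k)
    return word

vals = torch.tensor([3, 15, 0, 7, 1, 9, 4, 12])
word = pack_int4(vals)
# Unpack to verify the round trip.
unpacked = [(word >> (4 * k)) & 0xF for k in range(8)]
print(unpacked == vals.tolist())  # True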
# def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
#     from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils
#     config = LlamaConfig.from_pretrained(model)
#
#     def noop(*args, **kwargs):
#         pass
#
#     torch.nn.init.kaiming_uniform_ = noop
#     torch.nn.init.uniform_ = noop
#     torch.nn.init.normal_ = noop
#
#     torch.set_default_dtype(torch.half)
#     modeling_utils._init_weights = False
#     torch.set_default_dtype(torch.half)
#     model = LlamaForCausalLM(config)
#     torch.set_default_dtype(torch.float)
#     if eval:
#         model = model.eval()
#     layers = find_layers(model)
#     for name in ['lm_head']:
#         if name in layers:
#             del layers[name]
#     quant.make_quant_linear(model, layers, wbits, groupsize)
#
#     del layers
#
#     print('Loading model ...')
#     if checkpoint.endswith('.safetensors'):
#         from safetensors.torch import load_file as safe_load
#         model.load_state_dict(safe_load(checkpoint))
#     else:
#         model.load_state_dict(torch.load(checkpoint))
#
#     if eval:
#         quant.make_quant_attn(model)
#         quant.make_quant_norm(model)
#         if fused_mlp:
#             quant.make_fused_mlp(model)
#
#     if warmup_autotune:
#         quant.autotune_warmup_linear(model, transpose=not (eval))
#         if eval and fused_mlp:
#             quant.autotune_warmup_fused(model)
#     model.seqlen = 2048
#     print('Done.')
#
#     return model


# def llama_multigpu(model, gpus, gpu_dist):
#     model.model.embed_tokens = model.model.embed_tokens.to(gpus[0])
#     if hasattr(model.model, 'norm') and model.model.norm:
#         model.model.norm = model.model.norm.to(gpus[0])
#     import copy
#     model.lm_head = copy.deepcopy(model.lm_head).to(gpus[0])
#
#     cache = {'mask': None, 'position_ids': None}
#
#     class MoveModule(nn.Module):
#
#         def __init__(self, module, invalidate_cache):
#             super().__init__()
#             self.module = module
#             self.dev = next(iter(self.module.parameters())).device
#             self.invalidate_cache = invalidate_cache
#
#         def forward(self, *inp, **kwargs):
#             inp = list(inp)
#             if inp[0].device != self.dev:
#                 inp[0] = inp[0].to(self.dev)
#
#             if cache['mask'] is None or cache['mask'].device != self.dev or self.invalidate_cache:
#                 cache['mask'] = kwargs['attention_mask'].to(self.dev)
#             kwargs['attention_mask'] = cache['mask']
#
#             if cache['position_ids'] is None or cache['position_ids'].device != self.dev or self.invalidate_cache:
#                 cache['position_ids'] = kwargs['position_ids'].to(self.dev)
#             kwargs['position_ids'] = cache['position_ids']
#
#             tmp = self.module(*inp, **kwargs)
#             return tmp
#
#     layers = model.model.layers
#     from math import ceil
#     if not gpu_dist:
#         pergpu = ceil(len(layers) / len(gpus))
#         for i in range(len(layers)):
#             layers[i] = MoveModule(layers[i].to(0 if i == 0 or i == len(layers) - 1 else gpus[(i - 1) // pergpu]), i == 0)
#     else:
#         assert gpu_dist[0] >= 2, "At least two layers must be on GPU 0."
#         assigned_gpus = [0] * (gpu_dist[0] - 1)
#         for i in range(1, len(gpu_dist)):
#             assigned_gpus = assigned_gpus + [i] * gpu_dist[i]
#
#         remaining_assignments = len(layers) - len(assigned_gpus) - 1
#         if remaining_assignments > 0:
#             assigned_gpus = assigned_gpus + [-1] * remaining_assignments
#
#         assigned_gpus = assigned_gpus + [0]
#
#         for i in range(len(layers)):
#             layers[i] = MoveModule(layers[i].to(gpus[assigned_gpus[i]]), i == 0)
#
#     model.gpus = gpus


# def benchmark(model, input_ids, check=False):
#     input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
#     torch.cuda.synchronize()
#
#     cache = {'past': None}
#
#     def clear_past(i):
#
#         def tmp(layer, inp, out):
#             if cache['past']:
#                 cache['past'][i] = None
#
#         return tmp
#
#     for i, layer in enumerate(model.model.layers):
#         layer.register_forward_hook(clear_past(i))
#
#     print('Benchmarking ...')
#
#     if check:
#         loss = nn.CrossEntropyLoss()
#         tot = 0.
#
#     def sync():
#         if hasattr(model, 'gpus'):
#             for gpu in model.gpus:
#                 torch.cuda.synchronize(gpu)
#         else:
#             torch.cuda.synchronize()
#
#     max_memory = 0
#     with torch.no_grad():
#         attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
#         times = []
#         for i in range(input_ids.numel()):
#             tick = time.time()
#             out = model(input_ids[:, i:i + 1], past_key_values=cache['past'], attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)))
#             sync()
#             times.append(time.time() - tick)
#             print(i, times[-1])
#             if hasattr(model, 'gpus'):
#                 mem_allocated = sum(torch.cuda.memory_allocated(gpu) for gpu in model.gpus) / 1024 / 1024
#             else:
#                 mem_allocated = torch.cuda.memory_allocated() / 1024 / 1024
#             max_memory = max(max_memory, mem_allocated)
#             if check and i != input_ids.numel() - 1:
#                 tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float()
#             cache['past'] = list(out.past_key_values)
#             del out
#         sync()
#         print('Median:', np.median(times))
#         if check:
#             print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item())
#         print('max memory(MiB):', max_memory)

def quantize(model_id: str, wbits: int, groupsize: int, output_safetensors: str = "model.safetensors"):
    # NOTE: the committed code saved to the undefined `args.save_safetensors`;
    # an explicit output path parameter (with an assumed default) is used instead.
    print("loading model")
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="balanced_low_0")
    print("LOADED model")
    model.seqlen = 2048

    dataset = "wikitext2"
    nsamples = 128
    seed = None

    dataloader, testloader = get_loaders(dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen)

    tick = time.time()
    quantizers = sequential(model, dataloader, DEV, nsamples, wbits, groupsize)
    print(time.time() - tick)

    # if args.benchmark:
    #     gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
    #     if len(gpus) > 1:
    #         llama_multigpu(model, gpus, gpu_dist)
    #     else:
    #         model = model.to(DEV)
    #     if args.benchmark:
    #         input_ids = next(iter(dataloader))[0][:, :args.benchmark]
    #         benchmark(model, input_ids, check=args.check)

    # if args.eval:
    #     datasets = ['wikitext2', 'ptb', 'c4']
    #     if args.new_eval:
    #         datasets = ['wikitext2', 'ptb-new', 'c4-new']
    #     for dataset in datasets:
    #         dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
    #         print(dataset)
    #         llama_eval(model, testloader, DEV)
    #
    # if args.test_generation:
    #     gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
    #     if len(gpus) > 1:
    #         llama_multigpu(model, gpus, gpu_dist)
    #     else:
    #         model = model.to(DEV)

    #     from transformers import LlamaTokenizer, TextStreamer
    #     tokenizer = LlamaTokenizer.from_pretrained(args.model, use_fast=False)
    #     input_ids = tokenizer(["The capital of New Mexico is"], return_tensors="pt").input_ids.to(gpus[0])
    #     streamer = TextStreamer(tokenizer)
    #     with torch.no_grad():
    #         generated_ids = model.generate(input_ids, streamer=streamer)

    # if args.quant_directory is not None:
    #     export_quant_table(quantizers, args.quant_directory)

    # if not args.observe and args.save:
    #     llama_pack(model, quantizers, args.wbits, args.groupsize)
    #     torch.save(model.state_dict(), args.save)

    # if not args.observe and args.save_safetensors:
    pack(model, quantizers, wbits, groupsize)
    from safetensors.torch import save_file as safe_save
    state_dict = model.state_dict()
    state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
    safe_save(state_dict, output_safetensors)
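End to end, the driver produces a single safetensors file in which the quantized linear layers are replaced by packed GPTQ buffers. A hedged sketch of inspecting the output (the path is illustrative; the actual key names depend on the model architecture):

from safetensors import safe_open

# qweight/qzeros/scales/g_idx entries replace the original `.weight`
# entries for each quantized linear layer.
with safe_open("model.safetensors", framework="pt") as f:
    for key in list(f.keys())[:8]:
        print(key, f.get_slice(key).get_shape())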
@@ -217,30 +217,11 @@ class TensorParallelHead(SuperLayer):
 class TensorParallelColumnLinear(SuperLayer):
     @classmethod
     def load(cls, config, prefix: str, weights, bias: bool):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
-        if bias:
-            bias = weights.get_sharded(f"{prefix}.bias", dim=0)
-        else:
-            bias = None
-        return cls(get_linear(weight, bias, config.quantize))
+        return cls.load_multi(config, [prefix], weights, bias, dim=0)
 
     @classmethod
     def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
-        if config.quantize == "gptq":
-            qweight = torch.cat([weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
-            qzeros = torch.cat([weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
-            scales = torch.cat([weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
-            w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
-            for w2 in w[1:]:
-                torch.testing.assert_close(w2, w[0])
-            g_idx = w[0]
-            # TODO Get that from file to be more generic
-            bits = 4
-            groupsize = 128
-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
-        else:
-            w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
-            weight = torch.cat(w, dim=dim)
+        weight = weights.get_multi_weights_col(prefixes, quantize=config.quantize)
 
         if bias:
             b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
@@ -258,19 +239,7 @@ class TensorParallelRowLinear(SuperLayer):
 
     @classmethod
     def load(cls, config, prefix: str, weights, bias: bool):
-        if config.quantize == "gptq":
-            qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
-            qzeros = weights.get_tensor(f"{prefix}.qzeros")
-            scales = weights.get_tensor(f"{prefix}.scales")
-            g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
-
-            # TODO Get that from file to be more generic
-            bits = 4
-            groupsize = 128
-
-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
-        else:
-            weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
 
         if bias and weights.process_group.rank() == 0:
             # Rank is only on the first rank process
@@ -82,3 +82,44 @@ class Weights:
         tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
 
+    def get_multi_weights_col(self, prefixes: List[str], quantize: str):
+        if quantize == "gptq":
+            try:
+                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
+            except RuntimeError:
+                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
+
+            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
+            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
+            w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
+            for w2 in w[1:]:
+                torch.testing.assert_close(w2, w[0])
+            g_idx = w[0]
+            # TODO Get that from file to be more generic
+            bits = 4
+            groupsize = 128
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
+        else:
+            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
+            weight = torch.cat(w, dim=0)
+        return weight
+
+    def get_multi_weights_row(self, prefix: str, quantize: str):
+        if quantize == "gptq":
+            try:
+                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+            except RuntimeError:
+                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
+            qzeros = self.get_tensor(f"{prefix}.qzeros")
+            scales = self.get_tensor(f"{prefix}.scales")
+            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+
+            # TODO Get that from file to be more generic
+            bits = 4
+            groupsize = 128
+
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
+        else:
+            weight = self.get_sharded(f"{prefix}.weight", dim=1)
+        return weight
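The opposite sharding dimensions for qweight above follow from its storage layout: a GPTQ-style QuantLinear keeps qweight as roughly (in_features / pack_factor, out_features), so a column-parallel split over output features shards dim=1, while a row-parallel split over input features shards dim=0. An illustrative sketch (assumed shapes, pack factor 8 for 4-bit):

import torch

in_features, out_features, pack = 4096, 11008, 8
qweight = torch.zeros(in_features // pack, out_features, dtype=torch.int32)

col_shards = torch.chunk(qweight, 2, dim=1)  # TensorParallelColumnLinear
row_shards = torch.chunk(qweight, 2, dim=0)  # TensorParallelRowLinear
print(col_shards[0].shape)  # torch.Size([512, 5504])
print(row_shards[0].shape)  # torch.Size([256, 11008])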