Some fixes.

This commit is contained in:
Ubuntu 2023-06-13 14:08:37 +00:00 committed by Nicolas Patry
parent a0a194c391
commit ae308f88ec
4 changed files with 348 additions and 626 deletions

View File

@ -12,121 +12,66 @@ try:
# code based https://github.com/fpgaminer/GPTQ-triton
@custom_autotune.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=4,
),
triton.Config({
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}, num_stages=2, num_warps=8),
triton.Config({
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 8
}, num_stages=3, num_warps=8),
triton.Config({
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
}, num_stages=2, num_warps=4),
],
key=["M", "N", "K"],
key=['M', 'N', 'K'],
nearest_power_of_two=True,
prune_configs_by={
"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
"perf_model": None,
"top_k": None,
'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
'perf_model': None,
'top_k': None,
},
)
@triton.jit
def matmul_248_kernel(
a_ptr,
b_ptr,
c_ptr,
scales_ptr,
zeros_ptr,
g_ptr,
M,
N,
K,
bits,
maxq,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_scales,
stride_zeros,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, K) float16
@ -134,7 +79,7 @@ try:
C is of shape (M, N) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
@ -152,15 +97,10 @@ try:
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = offs_am[:, None] < M
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + (
(offs_k[:, None] // infearure_per_bits) * stride_bk
+ offs_bn[None, :] * stride_bn
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_k
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_bn[None, :]
@ -174,17 +114,13 @@ try:
g_idx = tl.load(g_ptrs)
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(
scales_ptrs + g_idx[:, None] * stride_scales
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(
zeros_ptrs + g_idx[:, None] * stride_zeros
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = zeros + 1
zeros = (zeros + 1)
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
@ -200,118 +136,61 @@ try:
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
@custom_autotune.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=4,
),
],
key=["M", "N", "K"],
nearest_power_of_two=True,
)
@custom_autotune.autotune(configs=[
triton.Config({
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 256,
'GROUP_SIZE_M': 8
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 8
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
}, num_stages=2, num_warps=8),
triton.Config({
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 8
}, num_stages=3, num_warps=8),
triton.Config({
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}, num_stages=2, num_warps=4),
],
key=['M', 'N', 'K'],
nearest_power_of_two=True)
@triton.jit
def transpose_matmul_248_kernel(
a_ptr,
b_ptr,
c_ptr,
scales_ptr,
zeros_ptr,
g_ptr,
M,
N,
K,
bits,
maxq,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_scales,
stride_zeros,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales,
stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, N) float16
@ -319,7 +198,7 @@ try:
C is of shape (M, K) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
@ -337,25 +216,16 @@ try:
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_n = tl.arange(0, BLOCK_SIZE_N)
a_ptrs = a_ptr + (
offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak
) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
a_mask = offs_am[:, None] < M
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + (
(offs_bk[:, None] // infearure_per_bits) * stride_bk
+ offs_n[None, :] * stride_bn
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_bk
g_idx = tl.load(g_ptrs)
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
zeros_ptrs = (
zeros_ptr
+ (offs_n[None, :] // infearure_per_bits)
+ g_idx[:, None] * stride_zeros
)
zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
shifter = (offs_bk % infearure_per_bits) * bits
zeros_shifter = (offs_n % infearure_per_bits) * bits
@ -367,9 +237,9 @@ try:
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = zeros + 1
zeros = (zeros + 1)
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
@ -381,84 +251,36 @@ try:
a_ptrs += BLOCK_SIZE_N
b_ptrs += BLOCK_SIZE_N
scales_ptrs += BLOCK_SIZE_N
zeros_ptrs += BLOCK_SIZE_N // infearure_per_bits
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
tl.store(c_ptrs, accumulator, mask=c_mask)
except:
print("triton not installed.")
print('triton not installed.')
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
with torch.cuda.device(input.device):
output = torch.empty(
(input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
)
grid = lambda META: (
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
)
matmul_248_kernel[grid](
input,
qweight,
output,
scales,
qzeros,
g_idx,
input.shape[0],
qweight.shape[1],
input.shape[1],
bits,
maxq,
input.stride(0),
input.stride(1),
qweight.stride(0),
qweight.stride(1),
output.stride(0),
output.stride(1),
scales.stride(0),
qzeros.stride(0),
)
output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), )
matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),
qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
return output
def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
with torch.cuda.device(input.device):
output_dim = (qweight.shape[0] * 32) // bits
output = torch.empty(
(input.shape[0], output_dim), device=input.device, dtype=torch.float16
)
grid = lambda META: (
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(output_dim, META["BLOCK_SIZE_K"]),
)
transpose_matmul_248_kernel[grid](
input,
qweight,
output,
scales,
qzeros,
g_idx,
input.shape[0],
qweight.shape[1],
output_dim,
bits,
maxq,
input.stride(0),
input.stride(1),
qweight.stride(0),
qweight.stride(1),
output.stride(0),
output.stride(1),
scales.stride(0),
qzeros.stride(0),
)
output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16)
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), )
transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),
qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
return output
class QuantLinearFunction(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
@ -475,9 +297,7 @@ class QuantLinearFunction(torch.autograd.Function):
grad_input = None
if ctx.needs_input_grad[0]:
grad_input = transpose_matmul248(
grad_output, qweight, scales, qzeros, g_idx, bits, maxq
)
grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
return grad_input, None, None, None, None, None, None
@ -500,39 +320,72 @@ class QuantLinear(nn.Module):
@classmethod
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
super().__init__()
if bits not in [2, 4, 8]:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = torch.zeros(
(infeatures // 32 * self.bits, outfeatures), dtype=torch.int32
)
qzeros = torch.zeros(
(math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits),
dtype=torch.int32,
)
scales = torch.zeros(
(math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16
)
g_idx = torch.tensor(
[i // self.groupsize for i in range(infeatures)], dtype=torch.int32
)
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
qzeros = torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 32 * bits), dtype=torch.int32)
scales = torch.zeros((math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16)
g_idx = torch.tensor([i // groupsize for i in range(infeatures)], dtype=torch.int32)
if bias:
bias = torch.zeros((outfeatures), dtype=torch.float16)
else:
bias = None
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
def pack(self, linear, scales, zeros, g_idx=None):
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
out = QuantLinearFunction.apply(
x.reshape(-1, x.shape[-1]),
self.qweight,
self.scales,
self.qzeros,
self.g_idx,
self.bits,
self.maxq,
)
out_shape = x.shape[:-1] + (self.outfeatures, )
out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)

View File

@ -4,36 +4,30 @@ import numpy as np
import torch
import torch.nn as nn
import math
import json
import os
from texttable import Texttable
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import transformers
import numpy as np
import torch
from text_generation_server.utils.gptq.quant_linear import QuantLinear
from loguru import logger
DEV = torch.device("cuda:0")
class Quantizer(nn.Module):
def __init__(self, shape=1):
super(Quantizer, self).__init__()
self.register_buffer("maxq", torch.tensor(0))
self.register_buffer("scale", torch.zeros(shape))
self.register_buffer("zero", torch.zeros(shape))
self.register_buffer('maxq', torch.tensor(0))
self.register_buffer('scale', torch.zeros(shape))
self.register_buffer('zero', torch.zeros(shape))
def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False):
def configure(
self,
bits,
perchannel=False,
sym=True,
mse=False,
norm=2.4,
grid=100,
maxshrink=0.8,
trits=False,
):
self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel
self.sym = sym
@ -94,16 +88,14 @@ class Quantizer(nn.Module):
self.zero = torch.round(-xmin / self.scale)
if self.mse:
best = torch.full([x.shape[0]], float("inf"), device=dev)
best = torch.full([x.shape[0]], float('inf'), device=dev)
for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid
xmin1 = p * xmin
xmax1 = p * xmax
scale1 = (xmax1 - xmin1) / self.maxq
zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
q = self._quantize(
x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq
)
q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
q -= x
q.abs_()
q.pow_(self.norm)
@ -150,6 +142,7 @@ class Quantizer(nn.Module):
class GPTQ:
def __init__(self, layer, observe=False):
self.layer = layer
self.dev = self.layer.weight.device
@ -177,19 +170,12 @@ class GPTQ:
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)
tmp = inp.shape[0]
if isinstance(self.layer, nn.Linear) or isinstance(
self.layer, transformers.Conv1D
):
if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.t()
if isinstance(self.layer, nn.Conv2d):
unfold = nn.Unfold(
self.layer.kernel_size,
dilation=self.layer.dilation,
padding=self.layer.padding,
stride=self.layer.stride,
)
unfold = nn.Unfold(self.layer.kernel_size, dilation=self.layer.dilation, padding=self.layer.padding, stride=self.layer.stride)
inp = unfold(inp)
inp = inp.permute([1, 0, 2])
inp = inp.flatten(1)
@ -202,14 +188,12 @@ class GPTQ:
def print_loss(self, name, q_weight, weight_error, timecost):
table = Texttable()
name += " " * (16 - len(name))
name += ' ' * (16 - len(name))
table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"])
table.header(['name', 'weight_error', 'fp_inp_SNR', 'q_inp_SNR', 'time'])
# assign weight
self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(
self.layer.weight.data.dtype
)
self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
if self.inp1 is not None:
# quantize input to int8
@ -223,15 +207,13 @@ class GPTQ:
q_SNR = torch_snr_error(q_out, self.out1).item()
fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item()
else:
q_SNR = "-"
fp_SNR = "-"
q_SNR = '-'
fp_SNR = '-'
table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])
print(table.draw().split("\n")[-2])
print(table.draw().split('\n')[-2])
def fasterquant(
self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False, name=""
):
def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, name=''):
self.layer.to(self.dev)
W = self.layer.weight.data.clone()
@ -290,9 +272,7 @@ class GPTQ:
if groupsize != -1:
if (i1 + i) % groupsize == 0:
self.quantizer.find_params(
W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
)
self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)
if ((i1 + i) // groupsize) - now_idx == -1:
scale.append(self.quantizer.scale)
@ -301,7 +281,7 @@ class GPTQ:
q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d**2
Losses1[:, i] = (w - q)**2 / d**2
err1 = (w - q) / d
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
@ -326,9 +306,7 @@ class GPTQ:
if isinstance(self.layer, transformers.Conv1D):
Q = Q.t()
self.print_loss(
name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)
)
self.print_loss(name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick))
if scale == []:
scale.append(self.quantizer.scale)
@ -348,18 +326,15 @@ class GPTQ:
def get_wikitext2(nsamples, seed, seqlen, model_id):
from datasets import load_dataset
traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
@ -374,21 +349,18 @@ def get_wikitext2(nsamples, seed, seqlen, model_id):
def get_ptb(nsamples, seed, seqlen, model_id):
from datasets import load_dataset
traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation")
traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation')
from transformers import AutoTokenizer
try:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
except:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt")
testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt")
trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt')
testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt')
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
@ -403,37 +375,22 @@ def get_ptb(nsamples, seed, seqlen, model_id):
def get_c4(nsamples, seed, seqlen, model_id):
from datasets import load_dataset
traindata = load_dataset(
"allenai/c4",
"allenai--c4",
data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
split="train",
use_auth_token=False,
)
valdata = load_dataset(
"allenai/c4",
"allenai--c4",
data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
split="validation",
use_auth_token=False,
)
traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', use_auth_token=False)
valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', use_auth_token=False)
from transformers import AutoTokenizer
try:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
except:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
while True:
i = random.randint(0, len(traindata) - 1)
trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
if trainenc.input_ids.shape[1] >= seqlen:
break
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
@ -444,13 +401,12 @@ def get_c4(nsamples, seed, seqlen, model_id):
trainloader.append((inp, tar))
import random
random.seed(0)
valenc = []
for _ in range(256):
while True:
i = random.randint(0, len(valdata) - 1)
tmp = tokenizer(valdata[i]["text"], return_tensors="pt")
tmp = tokenizer(valdata[i]['text'], return_tensors='pt')
if tmp.input_ids.shape[1] >= seqlen:
break
i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
@ -459,6 +415,7 @@ def get_c4(nsamples, seed, seqlen, model_id):
valenc = torch.hstack(valenc)
class TokenizerWrapper:
def __init__(self, input_ids):
self.input_ids = input_ids
@ -469,21 +426,18 @@ def get_c4(nsamples, seed, seqlen, model_id):
def get_ptb_new(nsamples, seed, seqlen, model_id):
from datasets import load_dataset
traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
from transformers import AutoTokenizer
try:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
except:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt")
testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt")
trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt')
testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt')
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
@ -498,35 +452,22 @@ def get_ptb_new(nsamples, seed, seqlen, model_id):
def get_c4_new(nsamples, seed, seqlen, model_id):
from datasets import load_dataset
traindata = load_dataset(
"allenai/c4",
"allenai--c4",
data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
split="train",
)
valdata = load_dataset(
"allenai/c4",
"allenai--c4",
data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
split="validation",
)
traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')
from transformers import AutoTokenizer
try:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
except:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
while True:
i = random.randint(0, len(traindata) - 1)
trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
if trainenc.input_ids.shape[1] >= seqlen:
break
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
@ -536,10 +477,11 @@ def get_c4_new(nsamples, seed, seqlen, model_id):
tar[:, :-1] = -100
trainloader.append((inp, tar))
valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt")
valenc = valenc.input_ids[:, : (256 * seqlen)]
valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
valenc = valenc.input_ids[:, :(256 * seqlen)]
class TokenizerWrapper:
def __init__(self, input_ids):
self.input_ids = input_ids
@ -548,46 +490,31 @@ def get_c4_new(nsamples, seed, seqlen, model_id):
return trainloader, valenc
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=""):
if "wikitext2" in name:
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=''):
if 'wikitext2' in name:
return get_wikitext2(nsamples, seed, seqlen, model_id)
if "ptb" in name:
if "new" in name:
if 'ptb' in name:
if 'new' in name:
return get_ptb_new(nsamples, seed, seqlen, model_id)
return get_ptb(nsamples, seed, seqlen, model_id)
if "c4" in name:
if "new" in name:
if 'c4' in name:
if 'new' in name:
return get_c4_new(nsamples, seed, seqlen, model_id)
return get_c4(nsamples, seed, seqlen, model_id)
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
# Skip last lm_head linear
if type(module) in layers and "lm_head" not in name:
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(
find_layers(
child, layers=layers, name=name + "." + name1 if name != "" else name1
)
)
res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
return res
@torch.no_grad()
def sequential(
model,
dataloader,
dev,
nsamples,
bits,
groupsize,
percdamp=0.01,
sym: bool = False,
act_order: bool = False,
):
print("Starting ...")
def sequential(model, dataloader, dev, nsamples, bits, groupsize, percdamp=0.01, sym: bool=False, act_order: bool = False):
print('Starting ...')
use_cache = model.config.use_cache
model.config.use_cache = False
@ -601,21 +528,20 @@ def sequential(
# layers[0] = layers[0].to(dev)
dtype = next(iter(model.parameters())).dtype
inps = torch.zeros(
(nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
)
cache = {"i": 0, "attention_mask": None}
inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
cache = {'i': 0, 'attention_mask': None}
class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, inp, **kwargs):
inps[cache["i"]] = inp
cache["i"] += 1
cache["attention_mask"] = kwargs["attention_mask"]
cache["position_ids"] = kwargs["position_ids"]
inps[cache['i']] = inp
cache['i'] += 1
cache['attention_mask'] = kwargs['attention_mask']
cache['position_ids'] = kwargs['position_ids']
raise ValueError
layers[0] = Catcher(layers[0])
@ -632,20 +558,19 @@ def sequential(
torch.cuda.empty_cache()
outs = torch.zeros_like(inps)
attention_mask = cache["attention_mask"].to(dev)
position_ids = cache["position_ids"].to(dev)
attention_mask = cache['attention_mask'].to(dev)
position_ids = cache['position_ids'].to(dev)
print("Ready.")
print('Ready.')
quantizers = {}
for i in range(len(layers)):
print(f"Quantizing layer {i+1}/{len(layers)}..")
print("+------------------+--------------+------------+-----------+-------+")
print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
print("+==================+==============+============+===========+=======+")
print(f'Quantizing layer {i+1}/{len(layers)}..')
print('+------------------+--------------+------------+-----------+-------+')
print('| name | weight_error | fp_inp_SNR | q_inp_SNR | time |')
print('+==================+==============+============+===========+=======+')
from accelerate.hooks import remove_hook_from_submodules
layer = layers[i].to(dev)
remove_hook_from_submodules(layer)
full = find_layers(layer)
@ -656,11 +581,10 @@ def sequential(
gptq = {}
for name in subset:
gptq[name] = GPTQ(subset[name])
gptq[name].quantizer.configure(
bits, perchannel=True, sym=sym, mse=False
)
gptq[name].quantizer.configure(bits, perchannel=True, sym=sym, mse=False)
def add_batch(name):
def tmp(_, inp, out):
gptq[name].add_batch(inp[0].data, out.data)
@ -670,38 +594,19 @@ def sequential(
for name in subset:
handles.append(subset[name].register_forward_hook(add_batch(name)))
for j in range(nsamples):
outs[j] = layer(
inps[j].unsqueeze(0),
attention_mask=attention_mask,
position_ids=position_ids,
)[0]
outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
for h in handles:
h.remove()
for name in subset:
scale, zero, g_idx, error = gptq[name].fasterquant(
percdamp=percdamp,
groupsize=groupsize,
actorder=act_order,
name=name,
)
quantizers["model.layers.%d.%s" % (i, name)] = (
gptq[name].quantizer.cpu(),
scale.cpu(),
zero.cpu(),
g_idx.cpu(),
bits,
groupsize,
)
scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=percdamp, groupsize=groupsize, actorder=act_order, name=name)
quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), bits, groupsize)
gptq[name].free()
for j in range(nsamples):
outs[j] = layer(
inps[j].unsqueeze(0),
attention_mask=attention_mask,
position_ids=position_ids,
)[0]
outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
layers[i] = layer.cpu()
del layer
@ -709,8 +614,8 @@ def sequential(
torch.cuda.empty_cache()
inps, outs = outs, inps
print("+------------------+--------------+------------+-----------+-------+")
print("\n")
print('+------------------+--------------+------------+-----------+-------+')
print('\n')
# if args.observe:
# observer.print()
@ -754,34 +659,34 @@ def sequential(
# @torch.no_grad()
# def llama_eval(model, testenc, dev):
# print('Evaluating ...')
#
#
# testenc = testenc.input_ids
# nsamples = testenc.numel() // model.seqlen
#
#
# use_cache = model.config.use_cache
# model.config.use_cache = False
# layers = model.model.layers
#
#
# model.model.embed_tokens = model.model.embed_tokens.to(dev)
# layers[0] = layers[0].to(dev)
#
#
# dtype = next(iter(model.parameters())).dtype
# inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
# cache = {'i': 0, 'attention_mask': None}
#
#
# class Catcher(nn.Module):
#
#
# def __init__(self, module):
# super().__init__()
# self.module = module
#
#
# def forward(self, inp, **kwargs):
# inps[cache['i']] = inp
# cache['i'] += 1
# cache['attention_mask'] = kwargs['attention_mask']
# cache['position_ids'] = kwargs['position_ids']
# raise ValueError
#
#
# layers[0] = Catcher(layers[0])
# for i in range(nsamples):
# batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
@ -790,19 +695,19 @@ def sequential(
# except ValueError:
# pass
# layers[0] = layers[0].module
#
#
# layers[0] = layers[0].cpu()
# model.model.embed_tokens = model.model.embed_tokens.cpu()
# torch.cuda.empty_cache()
#
#
# outs = torch.zeros_like(inps)
# attention_mask = cache['attention_mask']
# position_ids = cache['position_ids']
#
#
# for i in range(len(layers)):
# print(i)
# layer = layers[i].to(dev)
#
#
# if args.nearest:
# subset = find_layers(layer)
# for name in subset:
@ -811,18 +716,18 @@ def sequential(
# W = subset[name].weight.data
# quantizer.find_params(W, weight=True)
# subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype)
#
#
# for j in range(nsamples):
# outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
# layers[i] = layer.cpu()
# del layer
# torch.cuda.empty_cache()
# inps, outs = outs, inps
#
#
# if model.model.norm is not None:
# model.model.norm = model.model.norm.to(dev)
# model.lm_head = model.lm_head.to(dev)
#
#
# testenc = testenc.to(dev)
# nlls = []
# for i in range(nsamples):
@ -838,33 +743,21 @@ def sequential(
# nlls.append(neg_log_likelihood)
# ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
# print(ppl.item())
#
#
# model.config.use_cache = use_cache
def make_quant_linear(module, names, bits, groupsize, name=""):
def make_quant_linear(module, names, bits, groupsize, name=''):
if isinstance(module, QuantLinear):
return
for attr in dir(module):
tmp = getattr(module, attr)
name1 = name + "." + attr if name != "" else attr
name1 = name + '.' + attr if name != '' else attr
if name1 in names:
delattr(module, attr)
setattr(
module,
attr,
QuantLinear.new(
bits,
groupsize,
tmp.in_features,
tmp.out_features,
tmp.bias is not None,
),
)
setattr(module, attr, QuantLinear.new(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None))
for name1, child in module.named_children():
make_quant_linear(
child, names, bits, groupsize, name + "." + name1 if name != "" else name1
)
make_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
# TODO: perform packing on GPU
@ -873,26 +766,26 @@ def pack(model, quantizers, bits, groupsize):
layers = {n: layers[n] for n in quantizers}
make_quant_linear(model, quantizers, bits, groupsize)
qlayers = find_layers(model, [QuantLinear])
print("Packing ...")
print('Packing ...')
for name in qlayers:
print(name)
quantizers[name], scale, zero, g_idx, _, _ = quantizers[name]
qlayers[name].pack(layers[name], scale, zero, g_idx)
print("Done.")
print('Done.')
return model
# def load_quant(model, checkpoint, bits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
# from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils
# config = LlamaConfig.from_pretrained(model)
#
#
# def noop(*args, **kwargs):
# pass
#
#
# torch.nn.init.kaiming_uniform_ = noop
# torch.nn.init.uniform_ = noop
# torch.nn.init.normal_ = noop
#
#
# torch.set_default_dtype(torch.half)
# modeling_utils._init_weights = False
# torch.set_default_dtype(torch.half)
@ -905,29 +798,29 @@ def pack(model, quantizers, bits, groupsize):
# if name in layers:
# del layers[name]
# quant.make_quant_linear(model, layers, bits, groupsize)
#
#
# del layers
#
#
# print('Loading model ...')
# if checkpoint.endswith('.safetensors'):
# from safetensors.torch import load_file as safe_load
# model.load_state_dict(safe_load(checkpoint))
# else:
# model.load_state_dict(torch.load(checkpoint))
#
#
# if eval:
# quant.make_quant_attn(model)
# quant.make_quant_norm(model)
# if fused_mlp:
# quant.make_fused_mlp(model)
#
#
# if warmup_autotune:
# quant.autotune_warmup_linear(model, transpose=not (eval))
# if eval and fused_mlp:
# quant.autotune_warmup_fused(model)
# model.seqlen = 2048
# print('Done.')
#
#
# return model
@ -937,33 +830,33 @@ def pack(model, quantizers, bits, groupsize):
# model.model.norm = model.model.norm.to(gpus[0])
# import copy
# model.lm_head = copy.deepcopy(model.lm_head).to(gpus[0])
#
#
# cache = {'mask': None, 'position_ids': None}
#
#
# class MoveModule(nn.Module):
#
#
# def __init__(self, module, invalidate_cache):
# super().__init__()
# self.module = module
# self.dev = next(iter(self.module.parameters())).device
# self.invalidate_cache=invalidate_cache
#
#
# def forward(self, *inp, **kwargs):
# inp = list(inp)
# if inp[0].device != self.dev:
# inp[0] = inp[0].to(self.dev)
#
#
# if cache['mask'] is None or cache['mask'].device != self.dev or self.invalidate_cache:
# cache['mask'] = kwargs['attention_mask'].to(self.dev)
# kwargs['attention_mask'] = cache['mask']
#
#
# if cache['position_ids'] is None or cache['position_ids'].device != self.dev or self.invalidate_cache:
# cache['position_ids'] = kwargs['position_ids'].to(self.dev)
# kwargs['position_ids'] = cache['position_ids']
#
#
# tmp = self.module(*inp, **kwargs)
# return tmp
#
#
# layers = model.model.layers
# from math import ceil
# if not gpu_dist:
@ -975,49 +868,49 @@ def pack(model, quantizers, bits, groupsize):
# assigned_gpus = [0] * (gpu_dist[0]-1)
# for i in range(1, len(gpu_dist)):
# assigned_gpus = assigned_gpus + [i] * gpu_dist[i]
#
#
# remaining_assignments = len(layers)-len(assigned_gpus) - 1
# if remaining_assignments > 0:
# assigned_gpus = assigned_gpus + [-1] * remaining_assignments
#
#
# assigned_gpus = assigned_gpus + [0]
#
#
# for i in range(len(layers)):
# layers[i] = MoveModule(layers[i].to(gpus[assigned_gpus[i]]), i==0)
#
#
# model.gpus = gpus
#
#
#
#
# def benchmark(model, input_ids, check=False):
# input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
# torch.cuda.synchronize()
#
#
# cache = {'past': None}
#
#
# def clear_past(i):
#
#
# def tmp(layer, inp, out):
# if cache['past']:
# cache['past'][i] = None
#
#
# return tmp
#
#
# for i, layer in enumerate(model.model.layers):
# layer.register_forward_hook(clear_past(i))
#
#
# print('Benchmarking ...')
#
#
# if check:
# loss = nn.CrossEntropyLoss()
# tot = 0.
#
#
# def sync():
# if hasattr(model, 'gpus'):
# for gpu in model.gpus:
# torch.cuda.synchronize(gpu)
# else:
# torch.cuda.synchronize()
#
#
# max_memory = 0
# with torch.no_grad():
# attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
@ -1046,9 +939,7 @@ def pack(model, quantizers, bits, groupsize):
def quantize(model_id: str, bits: int, groupsize: int, output_dir: str):
print("loading model")
model = AutoModelForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float16, device_map="balanced_low_0"
)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="balanced_low_0")
print("LOADED model")
model.seqlen = 2048
@ -1056,9 +947,8 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str):
nsamples = 128
seed = None
dataloader, testloader = get_loaders(
dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen
)
dataloader, testloader = get_loaders(dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen)
tick = time.time()
quantizers = sequential(model, dataloader, DEV, nsamples, bits, groupsize)
@ -1082,7 +972,7 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str):
# dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
# print(dataset)
# llama_eval(model, testloader, DEV)
#
#
# if args.test_generation:
# gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
# if len(gpus) > 1:
@ -1096,7 +986,8 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str):
# streamer = TextStreamer(tokenizer)
# with torch.no_grad():
# generated_ids = model.generate(input_ids, streamer=streamer)
#
#
# if args.quant_directory is not None:
# export_quant_table(quantizers, args.quant_directory)
@ -1109,32 +1000,22 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str):
pack(model, quantizers, bits, groupsize)
from safetensors.torch import save_file
from transformers.modeling_utils import shard_checkpoint
state_dict = model.state_dict()
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
state_dict["gptq_bits"] = torch.LongTensor(bits)
state_dict["gptq_groupsize"] = torch.LongTensor(groupsize)
shards, index = shard_checkpoint(
state_dict, max_shard_size="10GB", weights_name="model.safetensors"
)
max_shard_size = "10GB"
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors")
os.makedirs(output_dir, exist_ok=True)
for shard_file, shard in shards.items():
save_file(
shard,
os.path.join(output_dir, shard_file),
metadata={
"format": "pt",
"quantized": "gptq",
"origin": "text-generation-inference",
},
)
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt", "quantized": "gptq", "origin": "text-generation-inference"})
if index is None:
path_to_weights = os.path.join(save_directory, "model.safetensors")
path_to_weights = os.path.join(output_dir, "model.safetensors")
logger.info(f"Model weights saved in {path_to_weights}")
else:
save_index_file = "model.safetensors.index.json"
save_index_file = os.path.join(save_directory, save_index_file)
save_index_file = os.path.join(output_dir, save_index_file)
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)

View File

@ -134,15 +134,13 @@ def get_linear(weight, bias, quantize):
try:
qweight, qzeros, scales, g_idx, bits, groupsize = weight
except Exception:
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
)
raise NotImplementedError(f"The passed weight is not `gptq` compatible, loader needs to be updated.")
linear = QuantLinear(
qweight,
qzeros,
scales,
g_idx,
g_idx,
bias,
bits,
groupsize,
@ -223,7 +221,7 @@ class TensorParallelColumnLinear(SuperLayer):
@classmethod
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
weight = weights.get_multi_weight_col(prefixes, quantize=config.quantize)
weight = weights.get_multi_weights_col(prefixes, quantize=config.quantize)
if bias:
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
@ -241,7 +239,7 @@ class TensorParallelRowLinear(SuperLayer):
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_multi_weight_row(prefix, quantize=config.quantize)
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process

View File

@ -86,20 +86,12 @@ class Weights:
def get_multi_weights_col(self, prefixes: List[str], quantize: str):
if quantize == "gptq":
try:
qweight = torch.cat(
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
qzeros = torch.cat(
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
@ -110,17 +102,15 @@ class Weights:
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim)
weight = torch.cat(w, dim=1)
return weight
def get_multi_self_row(self, prefix: str, quantize: str):
def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq":
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)