diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py index cd3c4d35..3093a700 100644 --- a/server/text_generation_server/utils/gptq/quant_linear.py +++ b/server/text_generation_server/utils/gptq/quant_linear.py @@ -12,121 +12,66 @@ try: # code based https://github.com/fpgaminer/GPTQ-triton @custom_autotune.autotune( configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=4, - ), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, num_stages=3, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=4), ], - key=["M", "N", "K"], + key=['M', 'N', 'K'], nearest_power_of_two=True, prune_configs_by={ - "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, - "perf_model": None, - "top_k": None, + 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, + 'perf_model': None, + 'top_k': None, }, ) @triton.jit - def matmul_248_kernel( - a_ptr, - b_ptr, - c_ptr, - scales_ptr, - zeros_ptr, - g_ptr, - M, - N, - K, - bits, - maxq, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_scales, - stride_zeros, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - ): + def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, 
stride_cn, stride_scales, stride_zeros, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): """ Compute the matrix multiplication C = A x B. A is of shape (M, K) float16 @@ -134,7 +79,7 @@ try: C is of shape (M, N) float16 scales is of shape (G, N) float16 zeros is of shape (G, N) float16 - g_ptr is of shape (K) int32 + g_ptr is of shape (K) int32 """ infearure_per_bits = 32 // bits @@ -152,15 +97,10 @@ try: offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + ( - offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak - ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = offs_am[:, None] < M + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = (offs_am[:, None] < M) # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ( - (offs_k[:, None] // infearure_per_bits) * stride_bk - + offs_bn[None, :] * stride_bn - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) g_ptrs = g_ptr + offs_k # shifter is used to extract the N bits of each element in the 32-bit word from B scales_ptrs = scales_ptr + offs_bn[None, :] @@ -174,17 +114,13 @@ try: g_idx = tl.load(g_ptrs) # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load( - scales_ptrs + g_idx[:, None] * stride_scales - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load( - zeros_ptrs + g_idx[:, None] * stride_zeros - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = zeros + 1 + zeros = (zeros + 1) - a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a = tl.load(a_ptrs, mask=a_mask, other=0.) 
# (BLOCK_SIZE_M, BLOCK_SIZE_K) b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated # Now we need to unpack b (which is N-bit values) into 32-bit values @@ -200,118 +136,61 @@ try: c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) - @custom_autotune.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=4, - ), - ], - key=["M", "N", "K"], - nearest_power_of_two=True, - ) + @custom_autotune.autotune(configs=[ + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 256, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, num_stages=3, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'], + nearest_power_of_two=True) @triton.jit - def transpose_matmul_248_kernel( - a_ptr, - b_ptr, - c_ptr, - scales_ptr, - zeros_ptr, - g_ptr, - M, - N, - K, - bits, - maxq, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_scales, - stride_zeros, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - ): + def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, + stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): """ Compute the matrix multiplication C = A x B. 
A is of shape (M, N) float16 @@ -319,7 +198,7 @@ try: C is of shape (M, K) float16 scales is of shape (G, N) float16 zeros is of shape (G, N) float16 - g_ptr is of shape (K) int32 + g_ptr is of shape (K) int32 """ infearure_per_bits = 32 // bits @@ -337,25 +216,16 @@ try: offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) offs_n = tl.arange(0, BLOCK_SIZE_N) - a_ptrs = a_ptr + ( - offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak - ) # (BLOCK_SIZE_M, BLOCK_SIZE_N) - a_mask = offs_am[:, None] < M + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + a_mask = (offs_am[:, None] < M) # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ( - (offs_bk[:, None] // infearure_per_bits) * stride_bk - + offs_n[None, :] * stride_bn - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) g_ptrs = g_ptr + offs_bk g_idx = tl.load(g_ptrs) # shifter is used to extract the N bits of each element in the 32-bit word from B scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales - zeros_ptrs = ( - zeros_ptr - + (offs_n[None, :] // infearure_per_bits) - + g_idx[:, None] * stride_zeros - ) + zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros shifter = (offs_bk % infearure_per_bits) * bits zeros_shifter = (offs_n % infearure_per_bits) * bits @@ -367,9 +237,9 @@ try: zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = zeros + 1 + zeros = (zeros + 1) - a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + a = tl.load(a_ptrs, mask=a_mask, other=0.) 
# (BLOCK_SIZE_M, BLOCK_SIZE_N) b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated # Now we need to unpack b (which is N-bit values) into 32-bit values @@ -381,84 +251,36 @@ try: a_ptrs += BLOCK_SIZE_N b_ptrs += BLOCK_SIZE_N scales_ptrs += BLOCK_SIZE_N - zeros_ptrs += BLOCK_SIZE_N // infearure_per_bits + zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits) c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :] c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) tl.store(c_ptrs, accumulator, mask=c_mask) - except: - print("triton not installed.") + print('triton not installed.') def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): with torch.cuda.device(input.device): - output = torch.empty( - (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16 - ) - grid = lambda META: ( - triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) - * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), - ) - matmul_248_kernel[grid]( - input, - qweight, - output, - scales, - qzeros, - g_idx, - input.shape[0], - qweight.shape[1], - input.shape[1], - bits, - maxq, - input.stride(0), - input.stride(1), - qweight.stride(0), - qweight.stride(1), - output.stride(0), - output.stride(1), - scales.stride(0), - qzeros.stride(0), - ) + output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) + grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), ) + matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), + qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) return output def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): with torch.cuda.device(input.device): output_dim = (qweight.shape[0] * 32) // bits - output = torch.empty( - (input.shape[0], output_dim), device=input.device, dtype=torch.float16 - ) - grid = lambda META: ( - triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) - * triton.cdiv(output_dim, META["BLOCK_SIZE_K"]), - ) - transpose_matmul_248_kernel[grid]( - input, - qweight, - output, - scales, - qzeros, - g_idx, - input.shape[0], - qweight.shape[1], - output_dim, - bits, - maxq, - input.stride(0), - input.stride(1), - qweight.stride(0), - qweight.stride(1), - output.stride(0), - output.stride(1), - scales.stride(0), - qzeros.stride(0), - ) + output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16) + grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), ) + transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), + qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) return output class QuantLinearFunction(torch.autograd.Function): + @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): @@ -475,9 +297,7 @@ class QuantLinearFunction(torch.autograd.Function): grad_input = None if ctx.needs_input_grad[0]: - grad_input = transpose_matmul248( - grad_output, qweight, scales, qzeros, g_idx, bits, maxq - ) + grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq) return 
grad_input, None, None, None, None, None, None @@ -500,39 +320,72 @@ class QuantLinear(nn.Module): @classmethod def new(cls, bits, groupsize, infeatures, outfeatures, bias): - super().__init__() if bits not in [2, 4, 8]: raise NotImplementedError("Only 2,4,8 bits are supported.") - qweight = torch.zeros( - (infeatures // 32 * self.bits, outfeatures), dtype=torch.int32 - ) - qzeros = torch.zeros( - (math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), - dtype=torch.int32, - ) - scales = torch.zeros( - (math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16 - ) - g_idx = torch.tensor( - [i // self.groupsize for i in range(infeatures)], dtype=torch.int32 - ) + + qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32) + qzeros = torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 32 * bits), dtype=torch.int32) + scales = torch.zeros((math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16) + g_idx = torch.tensor([i // groupsize for i in range(infeatures)], dtype=torch.int32) if bias: bias = torch.zeros((outfeatures), dtype=torch.float16) else: bias = None return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) + + def pack(self, linear, scales, zeros, g_idx=None): + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:, None]) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32 // self.bits + row += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1 + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32 // self.bits + col += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): - out_shape = x.shape[:-1] + (self.outfeatures,) - out = QuantLinearFunction.apply( - x.reshape(-1, x.shape[-1]), - self.qweight, - self.scales, - self.qzeros, - self.g_idx, - self.bits, - self.maxq, - ) + out_shape = x.shape[:-1] + (self.outfeatures, ) + out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq) out = out + self.bias if self.bias is not None else out return out.reshape(out_shape) diff --git a/server/text_generation_server/utils/gptq/quantize.py b/server/text_generation_server/utils/gptq/quantize.py index a86d518e..e3fae470 100644 --- 
a/server/text_generation_server/utils/gptq/quantize.py +++ b/server/text_generation_server/utils/gptq/quantize.py @@ -4,36 +4,30 @@ import numpy as np import torch import torch.nn as nn import math +import json import os from texttable import Texttable -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer import transformers import numpy as np import torch from text_generation_server.utils.gptq.quant_linear import QuantLinear +from loguru import logger DEV = torch.device("cuda:0") class Quantizer(nn.Module): + def __init__(self, shape=1): super(Quantizer, self).__init__() - self.register_buffer("maxq", torch.tensor(0)) - self.register_buffer("scale", torch.zeros(shape)) - self.register_buffer("zero", torch.zeros(shape)) + self.register_buffer('maxq', torch.tensor(0)) + self.register_buffer('scale', torch.zeros(shape)) + self.register_buffer('zero', torch.zeros(shape)) + + def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False): - def configure( - self, - bits, - perchannel=False, - sym=True, - mse=False, - norm=2.4, - grid=100, - maxshrink=0.8, - trits=False, - ): self.maxq = torch.tensor(2**bits - 1) self.perchannel = perchannel self.sym = sym @@ -94,16 +88,14 @@ class Quantizer(nn.Module): self.zero = torch.round(-xmin / self.scale) if self.mse: - best = torch.full([x.shape[0]], float("inf"), device=dev) + best = torch.full([x.shape[0]], float('inf'), device=dev) for i in range(int(self.maxshrink * self.grid)): p = 1 - i / self.grid xmin1 = p * xmin xmax1 = p * xmax scale1 = (xmax1 - xmin1) / self.maxq zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero - q = self._quantize( - x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq - ) + q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) q -= x q.abs_() q.pow_(self.norm) @@ -150,6 +142,7 @@ class Quantizer(nn.Module): class GPTQ: + def __init__(self, layer, observe=False): self.layer = layer self.dev = self.layer.weight.device @@ -177,19 +170,12 @@ class GPTQ: if len(inp.shape) == 2: inp = inp.unsqueeze(0) tmp = inp.shape[0] - if isinstance(self.layer, nn.Linear) or isinstance( - self.layer, transformers.Conv1D - ): + if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): if len(inp.shape) == 3: inp = inp.reshape((-1, inp.shape[-1])) inp = inp.t() if isinstance(self.layer, nn.Conv2d): - unfold = nn.Unfold( - self.layer.kernel_size, - dilation=self.layer.dilation, - padding=self.layer.padding, - stride=self.layer.stride, - ) + unfold = nn.Unfold(self.layer.kernel_size, dilation=self.layer.dilation, padding=self.layer.padding, stride=self.layer.stride) inp = unfold(inp) inp = inp.permute([1, 0, 2]) inp = inp.flatten(1) @@ -202,14 +188,12 @@ class GPTQ: def print_loss(self, name, q_weight, weight_error, timecost): table = Texttable() - name += " " * (16 - len(name)) + name += ' ' * (16 - len(name)) - table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"]) + table.header(['name', 'weight_error', 'fp_inp_SNR', 'q_inp_SNR', 'time']) # assign weight - self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to( - self.layer.weight.data.dtype - ) + self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) if self.inp1 is not None: # quantize input to int8 @@ -223,15 +207,13 @@ class GPTQ: q_SNR = torch_snr_error(q_out, self.out1).item() fp_SNR = torch_snr_error(self.layer(self.inp1), 
self.out1).item() else: - q_SNR = "-" - fp_SNR = "-" + q_SNR = '-' + fp_SNR = '-' table.add_row([name, weight_error, fp_SNR, q_SNR, timecost]) - print(table.draw().split("\n")[-2]) + print(table.draw().split('\n')[-2]) - def fasterquant( - self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False, name="" - ): + def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, name=''): self.layer.to(self.dev) W = self.layer.weight.data.clone() @@ -290,9 +272,7 @@ class GPTQ: if groupsize != -1: if (i1 + i) % groupsize == 0: - self.quantizer.find_params( - W[:, (i1 + i) : (i1 + i + groupsize)], weight=True - ) + self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) if ((i1 + i) // groupsize) - now_idx == -1: scale.append(self.quantizer.scale) @@ -301,7 +281,7 @@ class GPTQ: q = self.quantizer.quantize(w.unsqueeze(1)).flatten() Q1[:, i] = q - Losses1[:, i] = (w - q) ** 2 / d**2 + Losses1[:, i] = (w - q)**2 / d**2 err1 = (w - q) / d W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) @@ -326,9 +306,7 @@ class GPTQ: if isinstance(self.layer, transformers.Conv1D): Q = Q.t() - self.print_loss( - name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick) - ) + self.print_loss(name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)) if scale == []: scale.append(self.quantizer.scale) @@ -348,18 +326,15 @@ class GPTQ: def get_wikitext2(nsamples, seed, seqlen, model_id): from datasets import load_dataset - - traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") - testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') + testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) - trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt") - testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt") + trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') + testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') import random - random.seed(seed) trainloader = [] for _ in range(nsamples): @@ -374,21 +349,18 @@ def get_wikitext2(nsamples, seed, seqlen, model_id): def get_ptb(nsamples, seed, seqlen, model_id): from datasets import load_dataset - - traindata = load_dataset("ptb_text_only", "penn_treebank", split="train") - valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation") + traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') + valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation') from transformers import AutoTokenizer - try: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) except: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) - trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt") - testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt") + trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt') + testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') import random - random.seed(seed) trainloader = [] for _ in range(nsamples): @@ -403,37 +375,22 @@ def get_ptb(nsamples, seed, seqlen, model_id): def get_c4(nsamples, seed, seqlen, model_id): from datasets import load_dataset - - traindata = load_dataset( - "allenai/c4", - "allenai--c4", - 
data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, - split="train", - use_auth_token=False, - ) - valdata = load_dataset( - "allenai/c4", - "allenai--c4", - data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, - split="validation", - use_auth_token=False, - ) + traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', use_auth_token=False) + valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', use_auth_token=False) from transformers import AutoTokenizer - try: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) except: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) import random - random.seed(seed) trainloader = [] for _ in range(nsamples): while True: i = random.randint(0, len(traindata) - 1) - trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") + trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') if trainenc.input_ids.shape[1] >= seqlen: break i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) @@ -444,13 +401,12 @@ def get_c4(nsamples, seed, seqlen, model_id): trainloader.append((inp, tar)) import random - random.seed(0) valenc = [] for _ in range(256): while True: i = random.randint(0, len(valdata) - 1) - tmp = tokenizer(valdata[i]["text"], return_tensors="pt") + tmp = tokenizer(valdata[i]['text'], return_tensors='pt') if tmp.input_ids.shape[1] >= seqlen: break i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) @@ -459,6 +415,7 @@ def get_c4(nsamples, seed, seqlen, model_id): valenc = torch.hstack(valenc) class TokenizerWrapper: + def __init__(self, input_ids): self.input_ids = input_ids @@ -469,21 +426,18 @@ def get_c4(nsamples, seed, seqlen, model_id): def get_ptb_new(nsamples, seed, seqlen, model_id): from datasets import load_dataset - - traindata = load_dataset("ptb_text_only", "penn_treebank", split="train") - testdata = load_dataset("ptb_text_only", "penn_treebank", split="test") + traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') + testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') from transformers import AutoTokenizer - try: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) except: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) - trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt") - testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt") + trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') + testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') import random - random.seed(seed) trainloader = [] for _ in range(nsamples): @@ -498,35 +452,22 @@ def get_ptb_new(nsamples, seed, seqlen, model_id): def get_c4_new(nsamples, seed, seqlen, model_id): from datasets import load_dataset - - traindata = load_dataset( - "allenai/c4", - "allenai--c4", - data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, - split="train", - ) - valdata = load_dataset( - "allenai/c4", - "allenai--c4", - data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, - split="validation", - ) + traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train') + valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation') from 
transformers import AutoTokenizer - try: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) except: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) import random - random.seed(seed) trainloader = [] for _ in range(nsamples): while True: i = random.randint(0, len(traindata) - 1) - trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") + trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') if trainenc.input_ids.shape[1] >= seqlen: break i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) @@ -536,10 +477,11 @@ def get_c4_new(nsamples, seed, seqlen, model_id): tar[:, :-1] = -100 trainloader.append((inp, tar)) - valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt") - valenc = valenc.input_ids[:, : (256 * seqlen)] + valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') + valenc = valenc.input_ids[:, :(256 * seqlen)] class TokenizerWrapper: + def __init__(self, input_ids): self.input_ids = input_ids @@ -548,46 +490,31 @@ def get_c4_new(nsamples, seed, seqlen, model_id): return trainloader, valenc -def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=""): - if "wikitext2" in name: +def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=''): + if 'wikitext2' in name: return get_wikitext2(nsamples, seed, seqlen, model_id) - if "ptb" in name: - if "new" in name: + if 'ptb' in name: + if 'new' in name: return get_ptb_new(nsamples, seed, seqlen, model_id) return get_ptb(nsamples, seed, seqlen, model_id) - if "c4" in name: - if "new" in name: + if 'c4' in name: + if 'new' in name: return get_c4_new(nsamples, seed, seqlen, model_id) return get_c4(nsamples, seed, seqlen, model_id) -def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""): +def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): # Skip last lm_head linear if type(module) in layers and "lm_head" not in name: return {name: module} res = {} for name1, child in module.named_children(): - res.update( - find_layers( - child, layers=layers, name=name + "." + name1 if name != "" else name1 - ) - ) + res.update(find_layers(child, layers=layers, name=name + '.' 
+ name1 if name != '' else name1)) return res - @torch.no_grad() -def sequential( - model, - dataloader, - dev, - nsamples, - bits, - groupsize, - percdamp=0.01, - sym: bool = False, - act_order: bool = False, -): - print("Starting ...") +def sequential(model, dataloader, dev, nsamples, bits, groupsize, percdamp=0.01, sym: bool=False, act_order: bool = False): + print('Starting ...') use_cache = model.config.use_cache model.config.use_cache = False @@ -601,21 +528,20 @@ def sequential( # layers[0] = layers[0].to(dev) dtype = next(iter(model.parameters())).dtype - inps = torch.zeros( - (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev - ) - cache = {"i": 0, "attention_mask": None} + inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) + cache = {'i': 0, 'attention_mask': None} class Catcher(nn.Module): + def __init__(self, module): super().__init__() self.module = module def forward(self, inp, **kwargs): - inps[cache["i"]] = inp - cache["i"] += 1 - cache["attention_mask"] = kwargs["attention_mask"] - cache["position_ids"] = kwargs["position_ids"] + inps[cache['i']] = inp + cache['i'] += 1 + cache['attention_mask'] = kwargs['attention_mask'] + cache['position_ids'] = kwargs['position_ids'] raise ValueError layers[0] = Catcher(layers[0]) @@ -632,20 +558,19 @@ def sequential( torch.cuda.empty_cache() outs = torch.zeros_like(inps) - attention_mask = cache["attention_mask"].to(dev) - position_ids = cache["position_ids"].to(dev) + attention_mask = cache['attention_mask'].to(dev) + position_ids = cache['position_ids'].to(dev) - print("Ready.") + print('Ready.') quantizers = {} for i in range(len(layers)): - print(f"Quantizing layer {i+1}/{len(layers)}..") - print("+------------------+--------------+------------+-----------+-------+") - print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |") - print("+==================+==============+============+===========+=======+") + print(f'Quantizing layer {i+1}/{len(layers)}..') + print('+------------------+--------------+------------+-----------+-------+') + print('| name | weight_error | fp_inp_SNR | q_inp_SNR | time |') + print('+==================+==============+============+===========+=======+') from accelerate.hooks import remove_hook_from_submodules - layer = layers[i].to(dev) remove_hook_from_submodules(layer) full = find_layers(layer) @@ -656,11 +581,10 @@ def sequential( gptq = {} for name in subset: gptq[name] = GPTQ(subset[name]) - gptq[name].quantizer.configure( - bits, perchannel=True, sym=sym, mse=False - ) + gptq[name].quantizer.configure(bits, perchannel=True, sym=sym, mse=False) def add_batch(name): + def tmp(_, inp, out): gptq[name].add_batch(inp[0].data, out.data) @@ -670,38 +594,19 @@ def sequential( for name in subset: handles.append(subset[name].register_forward_hook(add_batch(name))) for j in range(nsamples): - outs[j] = layer( - inps[j].unsqueeze(0), - attention_mask=attention_mask, - position_ids=position_ids, - )[0] + + outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] for h in handles: h.remove() for name in subset: - scale, zero, g_idx, error = gptq[name].fasterquant( - percdamp=percdamp, - groupsize=groupsize, - actorder=act_order, - name=name, - ) - quantizers["model.layers.%d.%s" % (i, name)] = ( - gptq[name].quantizer.cpu(), - scale.cpu(), - zero.cpu(), - g_idx.cpu(), - bits, - groupsize, - ) + scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=percdamp, groupsize=groupsize, 
actorder=act_order, name=name) + quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), bits, groupsize) gptq[name].free() for j in range(nsamples): - outs[j] = layer( - inps[j].unsqueeze(0), - attention_mask=attention_mask, - position_ids=position_ids, - )[0] + outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] layers[i] = layer.cpu() del layer @@ -709,8 +614,8 @@ def sequential( torch.cuda.empty_cache() inps, outs = outs, inps - print("+------------------+--------------+------------+-----------+-------+") - print("\n") + print('+------------------+--------------+------------+-----------+-------+') + print('\n') # if args.observe: # observer.print() @@ -754,34 +659,34 @@ def sequential( # @torch.no_grad() # def llama_eval(model, testenc, dev): # print('Evaluating ...') -# +# # testenc = testenc.input_ids # nsamples = testenc.numel() // model.seqlen -# +# # use_cache = model.config.use_cache # model.config.use_cache = False # layers = model.model.layers -# +# # model.model.embed_tokens = model.model.embed_tokens.to(dev) # layers[0] = layers[0].to(dev) -# +# # dtype = next(iter(model.parameters())).dtype # inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) # cache = {'i': 0, 'attention_mask': None} -# +# # class Catcher(nn.Module): -# +# # def __init__(self, module): # super().__init__() # self.module = module -# +# # def forward(self, inp, **kwargs): # inps[cache['i']] = inp # cache['i'] += 1 # cache['attention_mask'] = kwargs['attention_mask'] # cache['position_ids'] = kwargs['position_ids'] # raise ValueError -# +# # layers[0] = Catcher(layers[0]) # for i in range(nsamples): # batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) @@ -790,19 +695,19 @@ def sequential( # except ValueError: # pass # layers[0] = layers[0].module -# +# # layers[0] = layers[0].cpu() # model.model.embed_tokens = model.model.embed_tokens.cpu() # torch.cuda.empty_cache() -# +# # outs = torch.zeros_like(inps) # attention_mask = cache['attention_mask'] # position_ids = cache['position_ids'] -# +# # for i in range(len(layers)): # print(i) # layer = layers[i].to(dev) -# +# # if args.nearest: # subset = find_layers(layer) # for name in subset: @@ -811,18 +716,18 @@ def sequential( # W = subset[name].weight.data # quantizer.find_params(W, weight=True) # subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype) -# +# # for j in range(nsamples): # outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] # layers[i] = layer.cpu() # del layer # torch.cuda.empty_cache() # inps, outs = outs, inps -# +# # if model.model.norm is not None: # model.model.norm = model.model.norm.to(dev) # model.lm_head = model.lm_head.to(dev) -# +# # testenc = testenc.to(dev) # nlls = [] # for i in range(nsamples): @@ -838,33 +743,21 @@ def sequential( # nlls.append(neg_log_likelihood) # ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) # print(ppl.item()) -# +# # model.config.use_cache = use_cache - -def make_quant_linear(module, names, bits, groupsize, name=""): +def make_quant_linear(module, names, bits, groupsize, name=''): if isinstance(module, QuantLinear): return for attr in dir(module): tmp = getattr(module, attr) - name1 = name + "." + attr if name != "" else attr + name1 = name + '.' 
+ attr if name != '' else attr if name1 in names: delattr(module, attr) - setattr( - module, - attr, - QuantLinear.new( - bits, - groupsize, - tmp.in_features, - tmp.out_features, - tmp.bias is not None, - ), - ) + setattr(module, attr, QuantLinear.new(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None)) for name1, child in module.named_children(): - make_quant_linear( - child, names, bits, groupsize, name + "." + name1 if name != "" else name1 - ) + make_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1) + # TODO: perform packing on GPU @@ -873,26 +766,26 @@ def pack(model, quantizers, bits, groupsize): layers = {n: layers[n] for n in quantizers} make_quant_linear(model, quantizers, bits, groupsize) qlayers = find_layers(model, [QuantLinear]) - print("Packing ...") + print('Packing ...') for name in qlayers: print(name) quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] qlayers[name].pack(layers[name], scale, zero, g_idx) - print("Done.") + print('Done.') return model # def load_quant(model, checkpoint, bits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True): # from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils # config = LlamaConfig.from_pretrained(model) -# +# # def noop(*args, **kwargs): # pass -# +# # torch.nn.init.kaiming_uniform_ = noop # torch.nn.init.uniform_ = noop # torch.nn.init.normal_ = noop -# +# # torch.set_default_dtype(torch.half) # modeling_utils._init_weights = False # torch.set_default_dtype(torch.half) @@ -905,29 +798,29 @@ def pack(model, quantizers, bits, groupsize): # if name in layers: # del layers[name] # quant.make_quant_linear(model, layers, bits, groupsize) -# +# # del layers -# +# # print('Loading model ...') # if checkpoint.endswith('.safetensors'): # from safetensors.torch import load_file as safe_load # model.load_state_dict(safe_load(checkpoint)) # else: # model.load_state_dict(torch.load(checkpoint)) -# +# # if eval: # quant.make_quant_attn(model) # quant.make_quant_norm(model) # if fused_mlp: # quant.make_fused_mlp(model) -# +# # if warmup_autotune: # quant.autotune_warmup_linear(model, transpose=not (eval)) # if eval and fused_mlp: # quant.autotune_warmup_fused(model) # model.seqlen = 2048 # print('Done.') -# +# # return model @@ -937,33 +830,33 @@ def pack(model, quantizers, bits, groupsize): # model.model.norm = model.model.norm.to(gpus[0]) # import copy # model.lm_head = copy.deepcopy(model.lm_head).to(gpus[0]) -# +# # cache = {'mask': None, 'position_ids': None} -# +# # class MoveModule(nn.Module): -# +# # def __init__(self, module, invalidate_cache): # super().__init__() # self.module = module # self.dev = next(iter(self.module.parameters())).device # self.invalidate_cache=invalidate_cache -# +# # def forward(self, *inp, **kwargs): # inp = list(inp) # if inp[0].device != self.dev: # inp[0] = inp[0].to(self.dev) -# +# # if cache['mask'] is None or cache['mask'].device != self.dev or self.invalidate_cache: # cache['mask'] = kwargs['attention_mask'].to(self.dev) # kwargs['attention_mask'] = cache['mask'] -# +# # if cache['position_ids'] is None or cache['position_ids'].device != self.dev or self.invalidate_cache: # cache['position_ids'] = kwargs['position_ids'].to(self.dev) # kwargs['position_ids'] = cache['position_ids'] -# +# # tmp = self.module(*inp, **kwargs) # return tmp -# +# # layers = model.model.layers # from math import ceil # if not gpu_dist: @@ -975,49 +868,49 @@ def pack(model, quantizers, bits, groupsize): # assigned_gpus = [0] * 
(gpu_dist[0]-1) # for i in range(1, len(gpu_dist)): # assigned_gpus = assigned_gpus + [i] * gpu_dist[i] -# +# # remaining_assignments = len(layers)-len(assigned_gpus) - 1 # if remaining_assignments > 0: # assigned_gpus = assigned_gpus + [-1] * remaining_assignments -# +# # assigned_gpus = assigned_gpus + [0] -# +# # for i in range(len(layers)): # layers[i] = MoveModule(layers[i].to(gpus[assigned_gpus[i]]), i==0) -# +# # model.gpus = gpus -# -# +# +# # def benchmark(model, input_ids, check=False): # input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV) # torch.cuda.synchronize() -# +# # cache = {'past': None} -# +# # def clear_past(i): -# +# # def tmp(layer, inp, out): # if cache['past']: # cache['past'][i] = None -# +# # return tmp -# +# # for i, layer in enumerate(model.model.layers): # layer.register_forward_hook(clear_past(i)) -# +# # print('Benchmarking ...') -# +# # if check: # loss = nn.CrossEntropyLoss() # tot = 0. -# +# # def sync(): # if hasattr(model, 'gpus'): # for gpu in model.gpus: # torch.cuda.synchronize(gpu) # else: # torch.cuda.synchronize() -# +# # max_memory = 0 # with torch.no_grad(): # attention_mask = torch.ones((1, input_ids.numel()), device=DEV) @@ -1046,9 +939,7 @@ def pack(model, quantizers, bits, groupsize): def quantize(model_id: str, bits: int, groupsize: int, output_dir: str): print("loading model") - model = AutoModelForCausalLM.from_pretrained( - model_id, torch_dtype=torch.float16, device_map="balanced_low_0" - ) + model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="balanced_low_0") print("LOADED model") model.seqlen = 2048 @@ -1056,9 +947,8 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str): nsamples = 128 seed = None - dataloader, testloader = get_loaders( - dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen - ) + + dataloader, testloader = get_loaders(dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen) tick = time.time() quantizers = sequential(model, dataloader, DEV, nsamples, bits, groupsize) @@ -1082,7 +972,7 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str): # dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen) # print(dataset) # llama_eval(model, testloader, DEV) - # + # # if args.test_generation: # gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())] # if len(gpus) > 1: @@ -1096,7 +986,8 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str): # streamer = TextStreamer(tokenizer) # with torch.no_grad(): # generated_ids = model.generate(input_ids, streamer=streamer) - # + # + # if args.quant_directory is not None: # export_quant_table(quantizers, args.quant_directory) @@ -1109,32 +1000,22 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str): pack(model, quantizers, bits, groupsize) from safetensors.torch import save_file from transformers.modeling_utils import shard_checkpoint - state_dict = model.state_dict() state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} state_dict["gptq_bits"] = torch.LongTensor(bits) state_dict["gptq_groupsize"] = torch.LongTensor(groupsize) - shards, index = shard_checkpoint( - state_dict, max_shard_size="10GB", weights_name="model.safetensors" - ) + max_shard_size = "10GB" + shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors") os.makedirs(output_dir, exist_ok=True) for 
shard_file, shard in shards.items(): - save_file( - shard, - os.path.join(output_dir, shard_file), - metadata={ - "format": "pt", - "quantized": "gptq", - "origin": "text-generation-inference", - }, - ) + save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt", "quantized": "gptq", "origin": "text-generation-inference"}) if index is None: - path_to_weights = os.path.join(save_directory, "model.safetensors") + path_to_weights = os.path.join(output_dir, "model.safetensors") logger.info(f"Model weights saved in {path_to_weights}") else: save_index_file = "model.safetensors.index.json" - save_index_file = os.path.join(save_directory, save_index_file) + save_index_file = os.path.join(output_dir, save_index_file) with open(save_index_file, "w", encoding="utf-8") as f: content = json.dumps(index, indent=2, sort_keys=True) + "\n" f.write(content) diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index e85e8f2f..b866f091 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -134,15 +134,13 @@ def get_linear(weight, bias, quantize): try: qweight, qzeros, scales, g_idx, bits, groupsize = weight except Exception: - raise NotImplementedError( - f"The passed weight is not `gptq` compatible, loader needs to be updated." - ) + raise NotImplementedError(f"The passed weight is not `gptq` compatible, loader needs to be updated.") linear = QuantLinear( qweight, qzeros, scales, - g_idx, + g_idx, bias, bits, groupsize, @@ -223,7 +221,7 @@ class TensorParallelColumnLinear(SuperLayer): @classmethod def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): - weight = weights.get_multi_weight_col(prefixes, quantize=config.quantize) + weight = weights.get_multi_weights_col(prefixes, quantize=config.quantize) if bias: b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] @@ -241,7 +239,7 @@ class TensorParallelRowLinear(SuperLayer): @classmethod def load(cls, config, prefix: str, weights, bias: bool): - weight = weights.get_multi_weight_row(prefix, quantize=config.quantize) + weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index bc3e284c..be47b15b 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -86,20 +86,12 @@ class Weights: def get_multi_weights_col(self, prefixes: List[str], quantize: str): if quantize == "gptq": try: - qweight = torch.cat( - [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 - ) + qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1) except RuntimeError: - raise RuntimeError( - "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" - ) + raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") - qzeros = torch.cat( - [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 - ) - scales = torch.cat( - [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 - ) + qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1) + scales = 
torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1) w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] for w2 in w[1:]: torch.testing.assert_close(w2, w[0]) @@ -110,17 +102,15 @@ class Weights: weight = (qweight, qzeros, scales, g_idx, bits, groupsize) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] - weight = torch.cat(w, dim=dim) + weight = torch.cat(w, dim=1) return weight - def get_multi_self_row(self, prefix: str, quantize: str): + def get_multi_weights_row(self, prefix: str, quantize: str): if quantize == "gptq": try: qweight = self.get_sharded(f"{prefix}.qweight", dim=0) except RuntimeError: - raise RuntimeError( - "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" - ) + raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") qzeros = self.get_tensor(f"{prefix}.qzeros") scales = self.get_tensor(f"{prefix}.scales") g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
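
Note on the packed layout used by QuantLinear.pack and the Triton kernels in this patch: each 32-bit word of qweight holds 32 // bits consecutive input-feature rows, and qzeros packs the same way along the output-feature axis, which is why the kernels divide offsets by infearure_per_bits and mask with maxq = 2**bits - 1 after shifting. The NumPy sketch below is only an illustration of that row packing and its inverse; the helper names and shapes are chosen for the example and are not part of the patch.

import numpy as np

def pack_rows(intweight: np.ndarray, bits: int) -> np.ndarray:
    # Pack quantized weights of shape (K, N), values in [0, 2**bits - 1],
    # into int32 words along K: 32 // bits values per word, low bits first.
    assert bits in (2, 4, 8)
    pack_factor = 32 // bits
    K, N = intweight.shape
    assert K % pack_factor == 0
    packed = np.zeros((K // pack_factor, N), dtype=np.uint32)
    for row in range(packed.shape[0]):
        for j in range(pack_factor):
            packed[row] |= intweight[row * pack_factor + j].astype(np.uint32) << (bits * j)
    return packed.astype(np.int32)

def unpack_rows(qweight: np.ndarray, bits: int) -> np.ndarray:
    # Inverse of pack_rows: recover the (K, N) integer weights.
    pack_factor = 32 // bits
    maxq = (1 << bits) - 1
    q = qweight.astype(np.uint32)
    out = np.empty((q.shape[0] * pack_factor, q.shape[1]), dtype=np.uint32)
    for j in range(pack_factor):
        out[j::pack_factor] = (q >> (bits * j)) & maxq
    return out

# Round-trip check on random 4-bit values.
rng = np.random.default_rng(0)
w = rng.integers(0, 16, size=(64, 8), dtype=np.uint32)
assert (unpack_rows(pack_rows(w, bits=4), bits=4) == w).all()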
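
Note on what matmul_248_kernel effectively computes: the visible hunk lines cover the shift/mask unpacking, the per-group scales lookup through g_idx, and the zeros + 1 offset, while the accumulation itself sits in unchanged context this diff does not show, so the reference below is inferred from those pieces rather than taken from the patch. As a mental model, the effective weight element is scales[g_idx[k], n] * (unpack(qweight)[k, n] - (unpack(qzeros)[g_idx[k], n] + 1)), and the kernel returns x @ W in fp16. A slow pure-PyTorch sketch with hypothetical helper names:

import torch

def dequantize_reference(qweight, qzeros, scales, g_idx, bits):
    # Rebuild the fp16 weight matrix the kernel multiplies against.
    pack_factor = 32 // bits
    maxq = (1 << bits) - 1
    shifts = torch.arange(0, 32, bits, device=qweight.device)
    # qweight: (K // pack_factor, N) int32 -> (K, N), packed along input features.
    w = (qweight.unsqueeze(1) >> shifts[None, :, None]) & maxq
    w = w.reshape(-1, qweight.shape[1])
    # qzeros: (G, N // pack_factor) int32 -> (G, N), packed along output features.
    z = (qzeros.unsqueeze(-1) >> shifts[None, None, :]) & maxq
    z = z.reshape(qzeros.shape[0], -1)
    g = g_idx.long()
    # The kernel adds 1 to the unpacked zero point before subtracting.
    return scales[g] * (w - (z[g] + 1)).to(scales.dtype)

def matmul248_reference(x, qweight, qzeros, scales, g_idx, bits):
    return x @ dequantize_reference(qweight, qzeros, scales, g_idx, bits)

Comparing this reference against the Triton matmul248 output for a freshly packed layer (within fp16 tolerance) is a cheap way to sanity-check the packing.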