fix per-column quantization

Felix Marty 2023-07-19 17:55:41 +00:00
parent edfbfdfb3f
commit 6bf7090ecd
2 changed files with 15 additions and 8 deletions

@@ -6,6 +6,8 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 import torch
 
+from loguru import logger
+
 try:
     from custom_kernels.exllama import make_q4, q4_matmul
 except Exception as e:
@@ -423,6 +425,7 @@ class Ex4bitLinear:
         if self.qzeros.shape[0] > 1:
             self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
 
-        assert groupsize == self.groupsize
+        if self.groupsize is not None:
+            assert groupsize == self.groupsize
 
         # Handle act-order matrix

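Note: the groupsize recovery in the hunk above relies on GPTQ packing eight 4-bit values into each int32 row of qweight, so qweight.shape[0] * 8 equals the number of input features, and dividing by the number of qzeros rows yields the group size. A minimal sketch with illustrative shapes (the tensor sizes below are assumptions, not taken from this repository):

import torch

# Assumed example shapes: 4096 input features, 11008 output features, groups of 128.
in_features, out_features, groupsize = 4096, 11008, 128

# GPTQ packs eight 4-bit values per int32: along dim 0 for qweight, along dim 1 for qzeros.
qweight = torch.zeros((in_features // 8, out_features), dtype=torch.int32)
qzeros = torch.zeros((in_features // groupsize, out_features // 8), dtype=torch.int32)

if qzeros.shape[0] > 1:
    recovered = (qweight.shape[0] * 8) // qzeros.shape[0]  # (512 * 8) // 32 == 128
else:
    # Per-column quantization: a single row of zero-points, no group size to recover.
    recovered = None

assert recovered == groupsize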
@@ -152,6 +152,8 @@ class Weights:
             except RuntimeError:
                 raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
 
+            bits, groupsize = self.get_gptq_qparams()
+
             if use_triton_kernel:
                 # The triton kernel reorders the scales/zero points instead of the weight/activation.
                 # Thus, each rank needs the full qzeros/scales.
@@ -159,10 +161,14 @@ class Weights:
                 scales = self.get_tensor(f"{prefix}.scales")
                 g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
             else:
-                # Exllama reorders the weights in advance and the activations on the fly, thus
-                # the scales and zero-points do not need to be reordered
-                qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
-                scales = self.get_sharded(f"{prefix}.scales", dim=0)
+                if groupsize >= 16:
+                    # Exllama reorders the weights in advance and the activations on the fly, thus
+                    # the scales and zero-points do not need to be reordered.
+                    qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
+                    scales = self.get_sharded(f"{prefix}.scales", dim=0)
+                else:
+                    qzeros = self.get_tensor(f"{prefix}.qzeros")
+                    scales = self.get_tensor(f"{prefix}.scales")
 
                 # For tp > 1, at this point we know we do not use act-order
                 if self.process_group.size() == 1:
@@ -170,8 +176,6 @@ class Weights:
                 else:
                     g_idx = None
 
-            bits, groupsize = self.get_gptq_qparams()
-
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
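Note: taken together, the weights changes fetch bits/groupsize before branching so that per-column quantized checkpoints (groupsize == -1 in the usual GPTQ convention, caught here by groupsize < 16, i.e. a single group spanning all input features) load the full qzeros/scales on every rank instead of sharding them along dim 0. A minimal sketch of that decision, assuming simplified stand-ins for the Weights.get_tensor / Weights.get_sharded helpers used in the diff (the function name below is hypothetical):

def load_gptq_qzeros_scales(weights, prefix, groupsize, use_triton_kernel):
    # Sketch only: `weights` is assumed to expose get_tensor/get_sharded as in the diff.
    if use_triton_kernel:
        # The triton kernel reorders the scales/zero-points itself, so each rank keeps them in full.
        qzeros = weights.get_tensor(f"{prefix}.qzeros")
        scales = weights.get_tensor(f"{prefix}.scales")
    elif groupsize >= 16:
        # Grouped quantization: groups line up with the row shards, so shard along dim 0.
        qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
        scales = weights.get_sharded(f"{prefix}.scales", dim=0)
    else:
        # Per-column quantization: one group covers all input features, nothing to shard row-wise.
        qzeros = weights.get_tensor(f"{prefix}.qzeros")
        scales = weights.get_tensor(f"{prefix}.scales")
    return qzeros, scales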