diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py
index 6b6d5cb4..1b807427 100644
--- a/server/text_generation_server/utils/gptq/quant_linear.py
+++ b/server/text_generation_server/utils/gptq/quant_linear.py
@@ -6,6 +6,8 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 
 import torch
 
+from loguru import logger
+
 try:
     from custom_kernels.exllama import make_q4, q4_matmul
 except Exception as e:
@@ -422,8 +424,9 @@ class Ex4bitLinear:
         self.groupsize = None
         if self.qzeros.shape[0] > 1:
             self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
-
-        assert groupsize == self.groupsize
+
+        if self.groupsize is not None:
+            assert groupsize == self.groupsize
 
         # Handle act-order matrix
         if self.g_idx is not None:
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index b4b3bfb5..5bbf04d0 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -152,6 +152,8 @@ class Weights:
             except RuntimeError:
                 raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
 
+            bits, groupsize = self.get_gptq_qparams()
+
             if use_triton_kernel:
                 # The triton kernel reorders the scales/zero points instead of the weight/activation.
                 # Thus, each rank needs the full qzeros/scales.
@@ -159,10 +161,14 @@ class Weights:
                 scales = self.get_tensor(f"{prefix}.scales")
                 g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
             else:
-                # Exllama reorders the weights in advance and the activations on the fly, thus
-                # the scales and zero-points do not need to be reordered
-                qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
-                scales = self.get_sharded(f"{prefix}.scales", dim=0)
+                if groupsize >= 16:
+                    # Exllama reorders the weights in advance and the activations on the fly, thus
+                    # the scales and zero-points do not need to be reordered.
+                    qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
+                    scales = self.get_sharded(f"{prefix}.scales", dim=0)
+                else:
+                    qzeros = self.get_tensor(f"{prefix}.qzeros")
+                    scales = self.get_tensor(f"{prefix}.scales")
 
                 # For tp > 1, at this point we know we do not use act-order
                 if self.process_group.size() == 1:
@@ -170,8 +176,6 @@ class Weights:
                 else:
                     g_idx = None
 
-                bits, groupsize = self.get_gptq_qparams()
-
                 weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
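
For context on the `Ex4bitLinear` change: below is a minimal sketch (not part of the patch) of the shape arithmetic behind the `groupsize` inference and the new `None` guard, assuming the standard GPTQ 4-bit packing layout of eight 4-bit values per int32. The concrete dimensions are illustrative, not taken from the diff.

```python
import torch

in_features, out_features, groupsize = 4096, 4096, 128

# qweight packs in_features 4-bit rows into in_features // 8 int32 rows.
qweight = torch.zeros(in_features // 8, out_features, dtype=torch.int32)
# qzeros holds one row of packed zero-points per quantization group.
qzeros = torch.zeros(in_features // groupsize, out_features // 8, dtype=torch.int32)

# The inference used in Ex4bitLinear: recover groupsize from packed shapes.
inferred_groupsize = (qweight.shape[0] * 8) // qzeros.shape[0]
assert inferred_groupsize == groupsize  # (4096 // 8) * 8 // 32 == 128

# When the checkpoint was quantized with a single group (groupsize == -1),
# qzeros.shape[0] == 1, the inference is skipped, and self.groupsize stays
# None -- hence the patch guards the assert with `if self.groupsize is not None`.
```

The `weights.py` change follows the same logic: `get_gptq_qparams()` is moved before the kernel branch so `groupsize` is already known when deciding whether the exllama path can shard `qzeros`/`scales` along dim 0 or must load them whole (`groupsize < 16`).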