From ee1f94e64b6b193898ad373df91c414527e5942f Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Tue, 13 Jun 2023 21:45:23 +0000
Subject: [PATCH] Fixing register bias + gptq_bits type.

---
 .../text_generation_server/utils/gptq/quant_linear.py | 10 +++++-----
 server/text_generation_server/utils/weights.py        |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py
index 738f54b7b..f818cb0ee 100644
--- a/server/text_generation_server/utils/gptq/quant_linear.py
+++ b/server/text_generation_server/utils/gptq/quant_linear.py
@@ -304,12 +304,12 @@ class QuantLinearFunction(torch.autograd.Function):
 class QuantLinear(nn.Module):
     def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
         super().__init__()
-        self.qweight = self.register_buffer("qweight", qweight)
-        self.qzeros = self.register_buffer("qzeros", qzeros)
-        self.scales = self.register_buffer("scales", scales)
-        self.g_idx = self.register_buffer("g_idx", g_idx)
+        self.register_buffer("qweight", qweight)
+        self.register_buffer("qzeros", qzeros)
+        self.register_buffer("scales", scales)
+        self.register_buffer("g_idx", g_idx)
         if bias is not None:
-            self.bias = self.register_buffer("bias", bias)
+            self.register_buffer("bias", bias)
         else:
             self.bias = None
         if bits not in [2, 4, 8]:
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 6d0d7c374..17f6ced5f 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -49,7 +49,7 @@ class Weights:
         tensor = f.get_tensor(tensor_name)
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32
-        if tensor.dtype != torch.int32:
+        if tensor.dtype not in [torch.int32, torch.int64]:
             tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor