Fixing register bias + gptq_bits type.

Authored by Ubuntu 2023-06-13 21:45:23 +00:00, committed by Nicolas Patry
parent ffe8fc4699
commit ee1f94e64b
2 changed files with 6 additions and 6 deletions


@@ -304,12 +304,12 @@ class QuantLinearFunction(torch.autograd.Function):
 class QuantLinear(nn.Module):
     def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
         super().__init__()
-        self.qweight = self.register_buffer("qweight", qweight)
-        self.qzeros = self.register_buffer("qzeros", qzeros)
-        self.scales = self.register_buffer("scales", scales)
-        self.g_idx = self.register_buffer("g_idx", g_idx)
+        self.register_buffer("qweight", qweight)
+        self.register_buffer("qzeros", qzeros)
+        self.register_buffer("scales", scales)
+        self.register_buffer("g_idx", g_idx)
         if bias is not None:
-            self.bias = self.register_buffer("bias", bias)
+            self.register_buffer("bias", bias)
         else:
             self.bias = None
         if bits not in [2, 4, 8]:
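
Note on this hunk: nn.Module.register_buffer returns None, so self.qweight = self.register_buffer("qweight", qweight) first registered the tensor and then immediately overwrote the just-registered buffer with None via attribute assignment. Calling register_buffer without the assignment is sufficient; PyTorch exposes the buffer as self.qweight. A minimal standalone sketch of the failure mode and the fix (the class names Broken and Fixed are illustrative, not from the repo):

import torch
import torch.nn as nn


class Broken(nn.Module):
    def __init__(self, qweight):
        super().__init__()
        # register_buffer returns None; the assignment then replaces
        # the just-registered buffer with None.
        self.qweight = self.register_buffer("qweight", qweight)


class Fixed(nn.Module):
    def __init__(self, qweight):
        super().__init__()
        # Registering is enough: the buffer is exposed as self.qweight.
        self.register_buffer("qweight", qweight)


w = torch.zeros(2, 2, dtype=torch.int32)
print(Broken(w).qweight)  # None
print(Fixed(w).qweight)   # the int32 tensor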


@@ -49,7 +49,7 @@ class Weights:
         tensor = f.get_tensor(tensor_name)
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32
-        if tensor.dtype != torch.int32:
+        if tensor.dtype not in [torch.int32, torch.int64]:
             tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
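
Note on this hunk: GPTQ packs 4-bit quantized values into int32 tensors (qweight, qzeros), so casting them to the model dtype (e.g. float16) would destroy the packed bits; judging by the commit title, the added torch.int64 case keeps the stored gptq_bits tensor intact as well. A minimal sketch of the widened guard, with convert as a hypothetical free-function stand-in for the Weights method:

import torch


def convert(tensor: torch.Tensor, dtype=torch.float16, device="cpu"):
    # Integer tensors (packed GPTQ weights, stored metadata) keep their
    # dtype; everything else is cast to the target dtype before the
    # device move.
    if tensor.dtype not in [torch.int32, torch.int64]:
        tensor = tensor.to(dtype=dtype)
    return tensor.to(device=device)


packed = torch.tensor([7, 12, 255], dtype=torch.int32)  # stands in for qweight
print(convert(packed).dtype)          # torch.int32 -- left untouched
print(convert(torch.randn(3)).dtype)  # torch.float16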