Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-21 14:52:20 +00:00
Fixing register bias + gptq_bits type.
commit ee1f94e64b
parent ffe8fc4699
@@ -304,12 +304,12 @@ class QuantLinearFunction(torch.autograd.Function):
 class QuantLinear(nn.Module):
     def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
         super().__init__()
-        self.qweight = self.register_buffer("qweight", qweight)
-        self.qzeros = self.register_buffer("qzeros", qzeros)
-        self.scales = self.register_buffer("scales", scales)
-        self.g_idx = self.register_buffer("g_idx", g_idx)
+        self.register_buffer("qweight", qweight)
+        self.register_buffer("qzeros", qzeros)
+        self.register_buffer("scales", scales)
+        self.register_buffer("g_idx", g_idx)
         if bias is not None:
-            self.bias = self.register_buffer("bias", bias)
+            self.register_buffer("bias", bias)
         else:
             self.bias = None
         if bits not in [2, 4, 8]:
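Why the first hunk is a fix: nn.Module.register_buffer returns None, so the old pattern self.qweight = self.register_buffer("qweight", qweight) first registered the tensor and then immediately overwrote the registered buffer with None via the assignment. Dropping the assignment lets register_buffer expose the tensor as an attribute itself. A minimal standalone sketch of the failure mode (toy module, not the TGI class):

    import torch
    import torch.nn as nn

    class Broken(nn.Module):
        def __init__(self):
            super().__init__()
            # register_buffer returns None; the assignment then
            # clobbers the freshly registered buffer with None.
            self.w = self.register_buffer("w", torch.ones(4))

    class Fixed(nn.Module):
        def __init__(self):
            super().__init__()
            # register_buffer itself makes the tensor reachable as
            # self.w and tracks it in state_dict() / .to(device).
            self.register_buffer("w", torch.ones(4))

    print(Broken().w)  # None -> e.g. the bias silently disappears
    print(Fixed().w)   # tensor([1., 1., 1., 1.])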
@@ -49,7 +49,7 @@ class Weights:
         tensor = f.get_tensor(tensor_name)
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32
-        if tensor.dtype != torch.int32:
+        if tensor.dtype not in [torch.int32, torch.int64]:
             tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
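Why the second hunk is needed: Weights.get_tensor casts everything it loads to the serving dtype (e.g. torch.float16), which is right for weights but destructive for GPTQ's integer tensors. The packed u4 weights come through as int32 (already exempted), and, per the commit title, fields such as gptq_bits are stored as int64, so that dtype must skip the cast too. A hedged standalone sketch of the fixed guard (load_tensor is an illustrative name, not the TGI API):

    import torch

    def load_tensor(tensor, target_dtype, device):
        # GPTQ stores packed u4 weights disguised as int32 and
        # scalar metadata (e.g. gptq_bits) as int64; casting either
        # to a float dtype would corrupt them, so only other dtypes
        # are converted before the device move.
        if tensor.dtype not in [torch.int32, torch.int64]:
            tensor = tensor.to(dtype=target_dtype)
        return tensor.to(device=device)

    bits = torch.tensor(4, dtype=torch.int64)  # survives as int64
    w = torch.randn(2, 2)                      # downcast to fp16
    assert load_tensor(bits, torch.float16, "cpu").dtype == torch.int64
    assert load_tensor(w, torch.float16, "cpu").dtype == torch.float16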