some more cleanup

Felix Marty 2023-07-05 16:42:13 +00:00
parent c858d791e5
commit 2272b3a456
2 changed files with 2 additions and 5 deletions

@@ -386,7 +386,7 @@ def ext_q4_matmul(x, q4, q4_width):
 class Ex4bitLinear:
     """Linear layer implementation with per-group 4-bit quantization of the weights"""
-    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize, device, world_size: int):
+    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize, device):
         assert bits == 4
         self.device = device
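
An aside for context, not part of the diff: GPTQ-style 4-bit layers pack eight 4-bit values into each int32 element of qweight along dim 0, which is what `assert bits == 4` guards and where the factor of 8 in the groupsize inference below comes from. A minimal packing sketch in PyTorch, with hypothetical helper names chosen only for illustration:

import torch

def pack_int4(nibbles: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: pack values in [0, 15] into int32, 8 per element,
    # mirroring the 8-to-1 ratio Ex4bitLinear assumes along dim 0 of qweight.
    assert nibbles.numel() % 8 == 0
    nibbles = nibbles.to(torch.int32).reshape(-1, 8)
    packed = torch.zeros(nibbles.shape[0], dtype=torch.int32)
    for i in range(8):
        packed |= nibbles[:, i] << (4 * i)
    return packed

def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of pack_int4: recover the 8 nibbles stored in each int32.
    return torch.stack([(packed >> (4 * i)) & 0xF for i in range(8)], dim=1).reshape(-1)

vals = torch.randint(0, 16, (16,), dtype=torch.int32)
assert torch.equal(unpack_int4(pack_int4(vals)), vals)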
@@ -417,9 +417,6 @@ class Ex4bitLinear:
         # Infer groupsize from height of qzeros
         self.groupsize = None
         if self.qzeros.shape[0] > 1:
-            if world_size is None:
-                world_size = 1
-            # self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0] // world_size)
             self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
             assert groupsize == self.groupsize
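
That packing explains the inference above: `self.qweight.shape[0] * 8` is the unpacked input row count, and qzeros holds one row of zero-points per quantization group, so their ratio is the group size. The deleted lines were already dead, since the `world_size` default only fed the commented-out variant of the formula, so both can go along with the constructor parameter. A worked example with illustrative numbers:

in_features = 4096                        # illustrative layer width
groupsize = 128                           # assumed quantization group size

qweight_rows = in_features // 8           # 512 packed int32 rows
qzeros_rows = in_features // groupsize    # 32 groups -> 32 zero-point rows

inferred = (qweight_rows * 8) // qzeros_rows
assert inferred == groupsize              # (512 * 8) // 32 == 128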

@@ -156,7 +156,7 @@ def get_linear(weight, bias, quantize, device = None):
                 f"The passed weight is not `gptq` compatible, loader needs to be updated."
             )
-        linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize, device, world_size)
+        linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize, device)
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
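
With `world_size` gone, the gptq branch of get_linear builds the layer from checkpoint tensors alone. A hedged sketch of the updated call, using placeholder tensors whose shapes follow common GPTQ packing conventions (assumptions, not read off this diff):

import torch

bits, groupsize = 4, 128
in_features, out_features = 4096, 4096    # illustrative only

# Assumed conventions: qweight packs 8 rows per int32 along dim 0,
# qzeros packs 8 zero-points per int32 along dim 1.
qweight = torch.zeros(in_features // 8, out_features, dtype=torch.int32)
qzeros = torch.zeros(in_features // groupsize, out_features // 8, dtype=torch.int32)
scales = torch.ones(in_features // groupsize, out_features, dtype=torch.float16)
g_idx = torch.arange(in_features, dtype=torch.int32) // groupsize
bias = None

linear = Ex4bitLinear(
    qweight, qzeros, scales, g_idx, bias, bits, groupsize, device="cuda:0"
)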