Fix `__call__` vs `forward`: make `Ex4bitLinear` subclass `torch.nn.Module` (with `super().__init__()`) so the module call protocol dispatches correctly.

This commit is contained in:
Nicolas Patry 2023-09-07 14:02:34 +00:00
parent b03d2621a7
commit 07bc903d6e

View File

@ -69,10 +69,11 @@ def create_exllama_buffers():
TEMP_STATE, TEMP_DQ = temp_state, temp_dq
class Ex4bitLinear:
class Ex4bitLinear(torch.nn.Module):
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
super().__init__()
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
assert bits == 4