[Bug Fix] Update import in quantization layers from nn to torch.nn based on import statements in the file header

This commit is contained in:
Dhruv Srikanth 2024-05-15 17:03:00 +01:00 committed by GitHub
parent a69ef52cf6
commit db7190d609
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -70,7 +70,7 @@ class Linear8bitLt(torch.nn.Module):
         return out

-class Linear4bit(nn.Module):
+class Linear4bit(torch.nn.Module):
     def __init__(self, weight, bias, quant_type):
         super().__init__()
         self.weight = Params4bit(