Fixing quantize script on models with non parameters buffers.

This commit is contained in:
Nicolas Patry 2023-07-20 11:16:34 +00:00
parent fe80f5360c
commit 929e374753

View File

@ -812,10 +812,13 @@ def load_weights_pre_hook(module_name, weights, recursive=False):
tensor = weights.get_tensor(tensor_name) tensor = weights.get_tensor(tensor_name)
setdeepattr(module, local_param, nn.Parameter(tensor)) setdeepattr(module, local_param, nn.Parameter(tensor))
else: else:
tensor = current_tensor.to(device=torch.device("cuda:0"))
if current_tensor.requires_grad:
tensor = nn.Parameter(tensor)
setdeepattr( setdeepattr(
module, module,
local_param, local_param,
nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))), tensor
) )
return inner return inner