Fixing quantize script on models with non-parameter buffers.

Nicolas Patry 2023-07-20 11:16:34 +00:00
parent fe80f5360c
commit 929e374753

@@ -812,10 +812,13 @@ def load_weights_pre_hook(module_name, weights, recursive=False):
                 tensor = weights.get_tensor(tensor_name)
                 setdeepattr(module, local_param, nn.Parameter(tensor))
             else:
+                tensor = current_tensor.to(device=torch.device("cuda:0"))
+                if current_tensor.requires_grad:
+                    tensor = nn.Parameter(tensor)
                 setdeepattr(
                     module,
                     local_param,
-                    nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))),
+                    tensor
                 )
 
     return inner
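
For context, a minimal sketch of why the patch distinguishes the two cases: PyTorch modules hold both parameters (which require gradients) and plain buffers (which do not), and assigning an nn.Parameter to a buffer attribute silently re-registers it as a parameter; for the integer buffers a quantized model typically carries, wrapping them would also fail outright, since integer tensors cannot require gradients. The sketch below mirrors the patched logic with a hypothetical TinyModel, a simplified stand-in for setdeepattr, and CPU in place of cuda:0; it is an illustration under those assumptions, not the script itself.

import torch
import torch.nn as nn


def setdeepattr(module, name, value):
    # Simplified stand-in for the helper used in the quantize script:
    # walks a dotted attribute path and assigns the final attribute.
    parts = name.split(".")
    for part in parts[:-1]:
        module = getattr(module, part)
    setattr(module, parts[-1], value)


class TinyModel(nn.Module):
    # Hypothetical module with one real parameter and one non-parameter buffer.
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))
        self.register_buffer("scale", torch.ones(4))


model = TinyModel()
named = list(model.named_parameters()) + list(model.named_buffers())

for name, current_tensor in named:
    # Same shape as the patched hook: move the tensor, but only re-wrap it
    # in nn.Parameter when it actually requires grad, so buffers stay buffers.
    tensor = current_tensor.to(device=torch.device("cpu"))  # "cuda:0" in the real script
    if current_tensor.requires_grad:
        tensor = nn.Parameter(tensor)
    setdeepattr(model, name, tensor)

assert isinstance(model.weight, nn.Parameter)
assert not isinstance(model.scale, nn.Parameter)  # would have become a Parameter before the fix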