diff --git a/server/text_generation_server/utils/gptq/quantize.py b/server/text_generation_server/utils/gptq/quantize.py index bee1e446..160c9c9f 100644 --- a/server/text_generation_server/utils/gptq/quantize.py +++ b/server/text_generation_server/utils/gptq/quantize.py @@ -812,10 +812,13 @@ def load_weights_pre_hook(module_name, weights, recursive=False): tensor = weights.get_tensor(tensor_name) setdeepattr(module, local_param, nn.Parameter(tensor)) else: + tensor = current_tensor.to(device=torch.device("cuda:0")) + if current_tensor.requires_grad: + tensor = nn.Parameter(tensor) setdeepattr( module, local_param, - nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))), + tensor ) return inner