From 08b8eec1d749f402d5d560d3c4cc09916320927b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 20 Jul 2023 16:04:15 +0200 Subject: [PATCH] fix(server): Fix handling of non-parameter tensors in the quantize script (`bigcode/starcoder` was an example). (#661) --- server/text_generation_server/utils/gptq/quantize.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/utils/gptq/quantize.py b/server/text_generation_server/utils/gptq/quantize.py index bee1e446..160c9c9f 100644 --- a/server/text_generation_server/utils/gptq/quantize.py +++ b/server/text_generation_server/utils/gptq/quantize.py @@ -812,10 +812,13 @@ def load_weights_pre_hook(module_name, weights, recursive=False): tensor = weights.get_tensor(tensor_name) setdeepattr(module, local_param, nn.Parameter(tensor)) else: + tensor = current_tensor.to(device=torch.device("cuda:0")) + if current_tensor.requires_grad: + tensor = nn.Parameter(tensor) setdeepattr( module, local_param, - nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))), + tensor ) return inner