From 929e374753d4525a2f8dbd5b9fa4f78038d7b52a Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 20 Jul 2023 11:16:34 +0000
Subject: [PATCH] Fixing quantize script on models with non-parameter buffers.

---
 server/text_generation_server/utils/gptq/quantize.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/utils/gptq/quantize.py b/server/text_generation_server/utils/gptq/quantize.py
index d182456f..8f4fb93d 100644
--- a/server/text_generation_server/utils/gptq/quantize.py
+++ b/server/text_generation_server/utils/gptq/quantize.py
@@ -812,10 +812,13 @@ def load_weights_pre_hook(module_name, weights, recursive=False):
                 tensor = weights.get_tensor(tensor_name)
                 setdeepattr(module, local_param, nn.Parameter(tensor))
             else:
+                tensor = current_tensor.to(device=torch.device("cuda:0"))
+                if current_tensor.requires_grad:
+                    tensor = nn.Parameter(tensor)
                 setdeepattr(
                     module,
                     local_param,
-                    nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))),
+                    tensor
                 )
 
     return inner