diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index f600a296..767a23b2 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -163,20 +163,17 @@ class Weights:
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
             elif quantize == "gptq" and quant_method == "awq":
                 log_once(
-                    logger.info,
-                    "Converting AWQ weights to Exllama/GPTQ packing format, "
-                    "in order used with Exllama/GPTQ kernels.",
+                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
                 from text_generation_server.utils.awq.conversion_utils import (
                     fast_awq_to_gptq,
                 )
 
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
-                g_idx = torch.zeros(
-                    (qweight.shape[0] * 32 // bits),
-                    dtype=torch.int32,
-                    device=qweight.device,
-                )
+                g_idx = (
+                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
+                    // groupsize
+                ).to(dtype=torch.int32)
             else:
                 g_idx = None
 
@@ -230,20 +227,17 @@ class Weights:
                 g_idx = w[0]
             elif quantize == "gptq" and quant_method == "awq":
                 log_once(
-                    logger.info,
-                    "Converting AWQ weights to Exllama/GPTQ packing format, "
-                    "in order used with Exllama/GPTQ kernels.",
+                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
                 from text_generation_server.utils.awq.conversion_utils import (
                     fast_awq_to_gptq,
                 )
 
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
-                g_idx = torch.zeros(
-                    (qweight.shape[0] * 32 // bits),
-                    dtype=torch.int32,
-                    device=qweight.device,
-                )
+                g_idx = (
+                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
+                    // groupsize
+                ).to(dtype=torch.int32)
             else:
                 g_idx = None
 
@@ -340,20 +334,17 @@ class Weights:
 
             if quant_method == "awq":
                 log_once(
-                    logger.info,
-                    "Converting AWQ weights to Exllama/GPTQ packing format, "
-                    "in order used with Exllama/GPTQ kernels.",
+                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
                 from text_generation_server.utils.awq.conversion_utils import (
                     fast_awq_to_gptq,
                 )
 
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
-                g_idx = torch.zeros(
-                    (qweight.shape[0] * 32 // bits),
-                    dtype=torch.int32,
-                    device=qweight.device,
-                )
+                g_idx = (
+                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
+                    // groupsize
+                ).to(dtype=torch.int32)
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
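
Note on the change above: the old code filled g_idx with zeros, assigning every row of the unpacked weight matrix to quantization group 0, which is only valid when a single group spans the whole matrix. GPTQ-layout kernels instead expect row i to map to group i // groupsize, which is what the new torch.arange expression produces. A minimal standalone sketch of the difference, assuming 4-bit packing and illustrative values for the packed row count and groupsize (neither value is taken from the patch):

    import torch

    bits = 4           # 32 // bits = 8 quantized rows packed per int32 row
    groupsize = 128    # illustrative; the real value comes from the GPTQ params
    packed_rows = 512  # stands in for qweight.shape[0]

    n_rows = packed_rows * (32 // bits)  # rows of the unpacked weight matrix

    # Old behaviour: every row mapped to group 0.
    g_idx_old = torch.zeros(n_rows, dtype=torch.int32)

    # New behaviour: sequential group layout, row i -> group i // groupsize.
    g_idx_new = (torch.arange(n_rows) // groupsize).to(dtype=torch.int32)

    assert g_idx_new[0].item() == 0
    assert g_idx_new[groupsize].item() == 1

Since AWQ checkpoints quantize groups of consecutive rows with shared scales/zeros, this sequential mapping is the correct g_idx for weights converted by fast_awq_to_gptq whenever groupsize is smaller than the full row count.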