From 0b5b85877975da51be7cc2342338609fcb07fa85 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 1 Feb 2024 13:30:43 +0000 Subject: [PATCH] fix missing g_idx and possible overflow in triton kernel --- .../utils/gptq/quant_linear.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py index bfc91c00..34895c01 100644 --- a/server/text_generation_server/utils/gptq/quant_linear.py +++ b/server/text_generation_server/utils/gptq/quant_linear.py @@ -182,7 +182,7 @@ try: ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = zeros + 1 + zeros = (zeros + 1) & maxq # add 1 and avoid overflow a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated @@ -251,7 +251,17 @@ class QuantLinear(nn.Module): self.register_buffer("qweight", qweight) self.register_buffer("qzeros", qzeros) self.register_buffer("scales", scales) - self.register_buffer("g_idx", g_idx) + if g_idx is not None: + self.register_buffer("g_idx", g_idx) + else: + self.register_buffer( + "g_idx", + torch.tensor( + [i // groupsize for i in range(qweight.shape[0] * 32 // bits)], + device=qweight.device, + dtype=torch.int32, + ), + ) if bias is not None: self.register_buffer("bias", bias) else: