From d31bdfdbae0d65ef22680bc1ff83973bec86e720 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Wed, 16 Jul 2025 22:46:05 -0700
Subject: [PATCH] support g_idx

Signed-off-by: Wang, Yi A
---
 .../text_generation_server/layers/gptq/hpu.py | 28 +++++++++++++++----
 1 file changed, 23 insertions(+), 5 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py b/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py
index 72944fa0..fa1d8a2e 100644
--- a/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py
+++ b/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py
@@ -89,13 +89,31 @@ class QuantLinear(nn.Module):
         g_idx_trivial = torch.tensor(
             g_idx_trivial, dtype=torch.int32, device=self.g_idx.device
         )
-        assert torch.equal(
-            self.g_idx, g_idx_trivial
-        ), "Non-trivial tensor g_idx is not supported"
+        sort_zeros = not (torch.equal(self.g_idx, g_idx_trivial))
         self.qzeros = self.qzeros.cpu()
         zeros = self.unpack_zeros_from_cuda_old_format()
-        new_qzeros = pack_tensor(zeros)
-        self.qzeros = new_qzeros.to(orig_device)
+        if sort_zeros:
+            zeros_group_1 = torch.zeros(
+                (self.infeatures, self.outfeatures),
+                dtype=zeros.dtype,
+                device=zeros.device,
+            )
+            scales = self.scales.cpu()
+            scale_group_1 = torch.zeros(
+                (self.infeatures, self.outfeatures),
+                dtype=scales.dtype,
+                device=scales.device,
+            )
+            for i in range(self.infeatures):
+                zeros_group_1[i] = zeros[self.g_idx[i]]
+                scale_group_1[i] = self.scales[self.g_idx[i]]
+            self.qzeros = pack_tensor(zeros_group_1).to(orig_device)
+            self.scales = scale_group_1.to(orig_device)
+            self.groupsize = 1
+            self.g_idx = None
+        else:
+            new_qzeros = pack_tensor(zeros)
+            self.qzeros = new_qzeros.to(orig_device)
 
     @classmethod
     def new(cls, bits, groupsize, infeatures, outfeatures, bias):
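
Note (illustrative sketch, not part of the patch): for a non-trivial g_idx, the
patch expands the per-group zero-points and scales into per-input-row tables,
then sets groupsize = 1 and drops g_idx, so the downstream dequantization path
never has to gather groups at matmul time. Below is a minimal vectorized
equivalent of the patch's per-row Python loop; `expand_groups` is a
hypothetical helper name, and the tensor shapes follow the patch
(zeros/scales: (n_groups, outfeatures), g_idx: (infeatures,)).

import torch

def expand_groups(zeros, scales, g_idx):
    # g_idx[i] names the quantization group of input row i, so advanced
    # indexing replicates each group's zero/scale row once per input feature.
    idx = g_idx.long()
    return zeros[idx], scales[idx]

# Toy example: 6 input features in 2 groups of 3, 4 output features.
zeros = torch.randint(0, 16, (2, 4), dtype=torch.int32)
scales = torch.rand(2, 4)
g_idx = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.int32)

zeros_row, scales_row = expand_groups(zeros, scales, g_idx)
assert zeros_row.shape == (6, 4) and scales_row.shape == (6, 4)
assert torch.equal(zeros_row[3], zeros[1])  # row 3 belongs to group 1

The trade-off is memory: the expanded tables are (infeatures, outfeatures)
rather than (infeatures // groupsize, outfeatures), in exchange for reusing
the existing trivial-groupsize code path unchanged.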