From 24c2bff65924801ddf90fa24fcc72752d4f45538 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Thu, 17 Jul 2025 22:00:12 +0800
Subject: [PATCH] Gaudi gptq gidx support (#3297)

Signed-off-by: Wang, Yi A
---
 .../text_generation_server/layers/gptq/hpu.py | 28 +++++++++++++++----
 server/text_generation_server/layers/lora.py  | 15 ++++++----
 2 files changed, 32 insertions(+), 11 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py b/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py
index 72944fa0..fa1d8a2e 100644
--- a/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py
+++ b/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py
@@ -89,13 +89,31 @@ class QuantLinear(nn.Module):
         g_idx_trivial = torch.tensor(
             g_idx_trivial, dtype=torch.int32, device=self.g_idx.device
         )
-        assert torch.equal(
-            self.g_idx, g_idx_trivial
-        ), "Non-trivial tensor g_idx is not supported"
+        sort_zeros = not (torch.equal(self.g_idx, g_idx_trivial))
         self.qzeros = self.qzeros.cpu()
         zeros = self.unpack_zeros_from_cuda_old_format()
-        new_qzeros = pack_tensor(zeros)
-        self.qzeros = new_qzeros.to(orig_device)
+        if sort_zeros:
+            zeros_group_1 = torch.zeros(
+                (self.infeatures, self.outfeatures),
+                dtype=zeros.dtype,
+                device=zeros.device,
+            )
+            scales = self.scales.cpu()
+            scale_group_1 = torch.zeros(
+                (self.infeatures, self.outfeatures),
+                dtype=scales.dtype,
+                device=scales.device,
+            )
+            for i in range(self.infeatures):
+                zeros_group_1[i] = zeros[self.g_idx[i]]
+                scale_group_1[i] = self.scales[self.g_idx[i]]
+            self.qzeros = pack_tensor(zeros_group_1).to(orig_device)
+            self.scales = scale_group_1.to(orig_device)
+            self.groupsize = 1
+            self.g_idx = None
+        else:
+            new_qzeros = pack_tensor(zeros)
+            self.qzeros = new_qzeros.to(orig_device)
 
     @classmethod
     def new(cls, bits, groupsize, infeatures, outfeatures, bias):

diff --git a/server/text_generation_server/layers/lora.py b/server/text_generation_server/layers/lora.py
index daac638c..54cc8bf6 100644
--- a/server/text_generation_server/layers/lora.py
+++ b/server/text_generation_server/layers/lora.py
@@ -16,12 +16,15 @@ else:
     punica_sgmv = None
 
 if SYSTEM == "ipex":
-    from intel_extension_for_pytorch.llm.functional import (
-        bgmv_expand,
-        bgmv_shrink,
-        sgmv_expand,
-        sgmv_shrink,
-    )
+    try:
+        from intel_extension_for_pytorch.llm.functional import (
+            bgmv_expand,
+            bgmv_shrink,
+            sgmv_expand,
+            sgmv_shrink,
+        )
+    except ImportError:
+        pass
 
 if TYPE_CHECKING:
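
Note: what the new sort_zeros path in hpu.py does is a per-row gather over g_idx
that expands the per-group zero points and scales to groupsize == 1, after which
g_idx can be dropped. A minimal standalone sketch of that remapping (not part of
the patch; names and shapes here are illustrative — `zeros` and `scales` stand
for the unpacked [num_groups, outfeatures] tensors, and g_idx[i] is the
quantization group of input feature i):

    import torch

    def expand_to_groupsize_1(zeros, scales, g_idx):
        # Row-wise gather, equivalent to the per-row loop in the patch:
        # row i receives the zero point / scale of the group that input
        # feature i belongs to, so the checkpoint can be treated as
        # groupsize == 1 and g_idx discarded.
        idx = g_idx.long()
        return zeros[idx], scales[idx]

    # Illustrative act-order g_idx: 8 input features mapped onto 2 groups.
    g_idx = torch.tensor([1, 0, 0, 1, 1, 0, 1, 0], dtype=torch.int32)
    zeros = torch.tensor([[8], [7]], dtype=torch.int32)         # [groups, outfeatures]
    scales = torch.tensor([[0.1], [0.2]], dtype=torch.float32)  # [groups, outfeatures]
    z1, s1 = expand_to_groupsize_1(zeros, scales, g_idx)
    assert z1.shape == (8, 1) and s1.shape == (8, 1)
    assert s1[0] == scales[1] and s1[1] == scales[0]  # rows follow g_idx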