Gaudi gptq gidx support (#3297)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi 2025-07-17 22:00:12 +08:00 committed by GitHub
parent fc2405c549
commit 24c2bff659
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 11 deletions

View File

@ -89,13 +89,31 @@ class QuantLinear(nn.Module):
g_idx_trivial = torch.tensor( g_idx_trivial = torch.tensor(
g_idx_trivial, dtype=torch.int32, device=self.g_idx.device g_idx_trivial, dtype=torch.int32, device=self.g_idx.device
) )
assert torch.equal( sort_zeros = not (torch.equal(self.g_idx, g_idx_trivial))
self.g_idx, g_idx_trivial
), "Non-trivial tensor g_idx is not supported"
self.qzeros = self.qzeros.cpu() self.qzeros = self.qzeros.cpu()
zeros = self.unpack_zeros_from_cuda_old_format() zeros = self.unpack_zeros_from_cuda_old_format()
new_qzeros = pack_tensor(zeros) if sort_zeros:
self.qzeros = new_qzeros.to(orig_device) zeros_group_1 = torch.zeros(
(self.infeatures, self.outfeatures),
dtype=zeros.dtype,
device=zeros.device,
)
scales = self.scales.cpu()
scale_group_1 = torch.zeros(
(self.infeatures, self.outfeatures),
dtype=scales.dtype,
device=scales.device,
)
for i in range(self.infeatures):
zeros_group_1[i] = zeros[self.g_idx[i]]
scale_group_1[i] = self.scales[self.g_idx[i]]
self.qzeros = pack_tensor(zeros_group_1).to(orig_device)
self.scales = scale_group_1.to(orig_device)
self.groupsize = 1
self.g_idx = None
else:
new_qzeros = pack_tensor(zeros)
self.qzeros = new_qzeros.to(orig_device)
@classmethod @classmethod
def new(cls, bits, groupsize, infeatures, outfeatures, bias): def new(cls, bits, groupsize, infeatures, outfeatures, bias):

View File

@ -16,12 +16,15 @@ else:
punica_sgmv = None punica_sgmv = None
if SYSTEM == "ipex": if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.functional import ( try:
bgmv_expand, from intel_extension_for_pytorch.llm.functional import (
bgmv_shrink, bgmv_expand,
sgmv_expand, bgmv_shrink,
sgmv_shrink, sgmv_expand,
) sgmv_shrink,
)
except ImportError:
pass
if TYPE_CHECKING: if TYPE_CHECKING: