mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-07-19 14:20:19 +00:00
Gaudi gptq gidx support (#3297)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
fc2405c549
commit
24c2bff659
@ -89,13 +89,31 @@ class QuantLinear(nn.Module):
|
|||||||
g_idx_trivial = torch.tensor(
|
g_idx_trivial = torch.tensor(
|
||||||
g_idx_trivial, dtype=torch.int32, device=self.g_idx.device
|
g_idx_trivial, dtype=torch.int32, device=self.g_idx.device
|
||||||
)
|
)
|
||||||
assert torch.equal(
|
sort_zeros = not (torch.equal(self.g_idx, g_idx_trivial))
|
||||||
self.g_idx, g_idx_trivial
|
|
||||||
), "Non-trivial tensor g_idx is not supported"
|
|
||||||
self.qzeros = self.qzeros.cpu()
|
self.qzeros = self.qzeros.cpu()
|
||||||
zeros = self.unpack_zeros_from_cuda_old_format()
|
zeros = self.unpack_zeros_from_cuda_old_format()
|
||||||
new_qzeros = pack_tensor(zeros)
|
if sort_zeros:
|
||||||
self.qzeros = new_qzeros.to(orig_device)
|
zeros_group_1 = torch.zeros(
|
||||||
|
(self.infeatures, self.outfeatures),
|
||||||
|
dtype=zeros.dtype,
|
||||||
|
device=zeros.device,
|
||||||
|
)
|
||||||
|
scales = self.scales.cpu()
|
||||||
|
scale_group_1 = torch.zeros(
|
||||||
|
(self.infeatures, self.outfeatures),
|
||||||
|
dtype=scales.dtype,
|
||||||
|
device=scales.device,
|
||||||
|
)
|
||||||
|
for i in range(self.infeatures):
|
||||||
|
zeros_group_1[i] = zeros[self.g_idx[i]]
|
||||||
|
scale_group_1[i] = self.scales[self.g_idx[i]]
|
||||||
|
self.qzeros = pack_tensor(zeros_group_1).to(orig_device)
|
||||||
|
self.scales = scale_group_1.to(orig_device)
|
||||||
|
self.groupsize = 1
|
||||||
|
self.g_idx = None
|
||||||
|
else:
|
||||||
|
new_qzeros = pack_tensor(zeros)
|
||||||
|
self.qzeros = new_qzeros.to(orig_device)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
|
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
|
||||||
|
@ -16,12 +16,15 @@ else:
|
|||||||
punica_sgmv = None
|
punica_sgmv = None
|
||||||
|
|
||||||
if SYSTEM == "ipex":
|
if SYSTEM == "ipex":
|
||||||
from intel_extension_for_pytorch.llm.functional import (
|
try:
|
||||||
bgmv_expand,
|
from intel_extension_for_pytorch.llm.functional import (
|
||||||
bgmv_shrink,
|
bgmv_expand,
|
||||||
sgmv_expand,
|
bgmv_shrink,
|
||||||
sgmv_shrink,
|
sgmv_expand,
|
||||||
)
|
sgmv_shrink,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
Loading…
Reference in New Issue
Block a user