Fix GPTQ issue: add `gptq`/`awq` to the dtype/quantize check and make `QuantLinear._preprocessing` device-agnostic

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
commit 8d221b7b79
parent 9914ffe1f1
Author: Wang, Yi A
Date:   2025-03-22 20:58:37 -07:00

2 changed files with 14 additions and 7 deletions

@@ -99,6 +99,8 @@ def serve(
         "bitsandbytes",
         "bitsandbytes-nf4",
         "bitsandbytes-fp4",
+        "gptq",
+        "awq",
     }:
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."

@@ -50,6 +50,9 @@ class QuantLinear(nn.Module):
         self.outfeatures = qweight.shape[1]
         self.infeatures = qweight.shape[0] * 32 // bits
+        self.wf = torch.tensor(
+            list(range(0, 32, self.bits)), dtype=torch.int32
+        ).unsqueeze(0)
         self._preprocessing()

     def unpack_zeros_from_cuda_old_format(self):
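
The new `self.wf` buffer is a row vector of bit offsets (0, bits, 2*bits, ..., 32 - bits) used to unpack the low-bit values packed into each int32 word. A self-contained sketch of that unpacking pattern, assuming 4-bit packing (illustrative, not the exact TGI unpack code):

    import torch

    bits = 4
    # Bit offset of each 4-bit value inside a 32-bit word: [[0, 4, 8, ..., 28]]
    wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0)

    packed = torch.tensor([[0x76543210]], dtype=torch.int32)  # 8 nibbles per word
    # Shift each word right by every offset, then mask off the low `bits` bits.
    unpacked = torch.bitwise_right_shift(packed.unsqueeze(-1), wf) & ((1 << bits) - 1)
    print(unpacked.flatten().tolist())  # [0, 1, 2, 3, 4, 5, 6, 7]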
@@ -75,22 +78,24 @@ class QuantLinear(nn.Module):
         return weight

     def _preprocessing(self):
+        orig_device = self.qweight.device
+        self.qweight = self.qweight.cpu()
         weight = self.unpack_weight_from_cuda_old_format()
         new_qweight = pack_tensor(weight)
-        self.qweight = new_qweight.to("hpu")
+        self.qweight = new_qweight.to(orig_device)
         # TODO: Support group indexing and remove the check
         columns = self.qweight.shape[0]
-        g_idx_trivial = [i // self.group_size for i in range(columns)]
-        g_idx_trivial = torch.tensor(g_idx_trivial, dtype=torch.int32)
+        g_idx_trivial = [i // self.groupsize for i in range(columns)]
+        g_idx_trivial = torch.tensor(
+            g_idx_trivial, dtype=torch.int32, device=self.g_idx.device
+        )
         assert torch.equal(
             self.g_idx, g_idx_trivial
         ), "Non-trivial tensor g_idx is not supported"
-        zeros = self.unpack_zeros_from_cuda_old_format().cpu()
+        self.qzeros = self.qzeros.cpu()
+        zeros = self.unpack_zeros_from_cuda_old_format()
         new_qzeros = pack_tensor(zeros)
-        self.qzeros = new_qzeros.to("hpu")
+        self.qzeros = new_qzeros.to(orig_device)

     @classmethod
     def new(cls, bits, groupsize, infeatures, outfeatures, bias):
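
Two fixes land in this last hunk: `_preprocessing` now remembers the buffer's original device instead of hard-coding `.to("hpu")`, and `g_idx_trivial` is built on `self.g_idx`'s device so the `torch.equal` comparison runs on matching devices. A minimal sketch of the CPU round-trip pattern, with hypothetical `unpack4`/`pack4` helpers standing in for the real `unpack_*` methods and `pack_tensor`:

    import torch

    def unpack4(q):
        # Hypothetical stand-in for unpack_weight_from_cuda_old_format (4-bit).
        wf = torch.arange(0, 32, 4, dtype=torch.int32)
        return (q.unsqueeze(-1) >> wf) & 0xF

    def pack4(vals):
        # Hypothetical stand-in for pack_tensor: inverse of unpack4.
        wf = torch.arange(0, 32, 4, dtype=torch.int32)
        return (vals << wf).sum(dim=-1, dtype=torch.int32)

    qweight = torch.randint(0, 2**31 - 1, (2, 3), dtype=torch.int32)
    orig_device = qweight.device                  # remember where the buffer lives
    repacked = pack4(unpack4(qweight.cpu())).to(orig_device)  # repack on CPU, move back
    assert torch.equal(qweight, repacked)         # lossless round-trip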