mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-05-02 23:42:06 +00:00
fix gptq issue
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
9914ffe1f1
commit
8d221b7b79
@ -99,6 +99,8 @@ def serve(
|
||||
"bitsandbytes",
|
||||
"bitsandbytes-nf4",
|
||||
"bitsandbytes-fp4",
|
||||
"gptq",
|
||||
"awq",
|
||||
}:
|
||||
raise RuntimeError(
|
||||
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
|
||||
|
@ -50,6 +50,9 @@ class QuantLinear(nn.Module):
|
||||
|
||||
self.outfeatures = qweight.shape[1]
|
||||
self.infeatures = qweight.shape[0] * 32 // bits
|
||||
self.wf = torch.tensor(
|
||||
list(range(0, 32, self.bits)), dtype=torch.int32
|
||||
).unsqueeze(0)
|
||||
self._preprocessing()
|
||||
|
||||
def unpack_zeros_from_cuda_old_format(self):
|
||||
@ -75,22 +78,24 @@ class QuantLinear(nn.Module):
|
||||
return weight
|
||||
|
||||
def _preprocessing(self):
|
||||
orig_device = self.qweight.device
|
||||
self.qweight = self.qweight.cpu()
|
||||
weight = self.unpack_weight_from_cuda_old_format()
|
||||
new_qweight = pack_tensor(weight)
|
||||
self.qweight = new_qweight.to("hpu")
|
||||
|
||||
self.qweight = new_qweight.to(orig_device)
|
||||
# TODO: Support group indexing and remove the check
|
||||
columns = self.qweight.shape[0]
|
||||
g_idx_trivial = [i // self.group_size for i in range(columns)]
|
||||
g_idx_trivial = torch.tensor(g_idx_trivial, dtype=torch.int32)
|
||||
g_idx_trivial = [i // self.groupsize for i in range(columns)]
|
||||
g_idx_trivial = torch.tensor(
|
||||
g_idx_trivial, dtype=torch.int32, device=self.g_idx.device
|
||||
)
|
||||
assert torch.equal(
|
||||
self.g_idx, g_idx_trivial
|
||||
), "Non-trivial tensor g_idx is not supported"
|
||||
|
||||
zeros = self.unpack_zeros_from_cuda_old_format().cpu()
|
||||
self.qzeros = self.qzeros.cpu()
|
||||
zeros = self.unpack_zeros_from_cuda_old_format()
|
||||
new_qzeros = pack_tensor(zeros)
|
||||
self.qzeros = new_qzeros.to("hpu")
|
||||
self.qzeros = new_qzeros.to(orig_device)
|
||||
|
||||
@classmethod
|
||||
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
|
||||
|
Loading…
Reference in New Issue
Block a user