Some GPTQ cases could not be handled by ipex, but could be handled by triton (#3298)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi 2025-08-19 15:37:49 +08:00 committed by GitHub
parent 5284b5c654
commit 6624fec1f9
2 changed files with 19 additions and 6 deletions

server/text_generation_server/layers/gptq/__init__.py

@@ -12,11 +12,7 @@ from text_generation_server.utils.weights import (
     WeightsLoader,
     DefaultWeightsLoader,
 )
-if SYSTEM == "ipex":
-    from .ipex import QuantLinear
-elif SYSTEM in {"cuda", "rocm"}:
-    from .triton import QuantLinear
+import math
 
 @dataclass
@@ -70,6 +66,19 @@ class GPTQWeight(Weight):
             return ExllamaQuantLinear(self, bias)
         else:
+            if SYSTEM == "ipex" and not (
+                self.device.type == "xpu"
+                and (
+                    self.bits != 4
+                    or math.ceil(
+                        (self.qweight.shape[0] * 32 // self.bits) / self.groupsize
+                    )
+                    != self.scales.shape[0]
+                )
+            ):
+                from .ipex import QuantLinear
+            else:
+                from .triton import QuantLinear
             return QuantLinear(
                 self.qweight,
                 self.qzeros,
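For context, a minimal standalone sketch of the shape check that drives the fallback above: it recomputes the expected number of quantization groups from the packed qweight and compares it to scales.shape[0]. The layer shapes below (bits, groupsize, in/out features) are hypothetical, chosen only to illustrate the check, and are not taken from the commit.

    import math

    import torch

    # Hypothetical GPTQ layer shapes, chosen only to illustrate the check.
    bits = 4
    groupsize = 128
    in_features, out_features = 4096, 11008

    # GPTQ packs 32 // bits quantized values into each int32 along dim 0,
    # so qweight has in_features * bits // 32 rows.
    qweight = torch.zeros((in_features * bits // 32, out_features), dtype=torch.int32)
    # scales holds one row per quantization group.
    scales = torch.zeros((in_features // groupsize, out_features), dtype=torch.float16)

    # Recover the group count implied by qweight's packed shape.
    expected_groups = math.ceil((qweight.shape[0] * 32 // bits) / groupsize)

    # On xpu, ipex's QuantLinear is only used for 4-bit weights whose scales
    # match the expected group count; any other layout falls back to triton.
    use_ipex = bits == 4 and expected_groups == scales.shape[0]
    print(use_ipex)  # True for this well-formed layout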

server/text_generation_server/layers/gptq/triton.py

@@ -202,7 +202,11 @@ def matmul_248_kernel(
 def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
-    with torch.cuda.device(input.device):
+    with (
+        torch.xpu.device(input.device)
+        if torch.xpu.is_available()
+        else torch.cuda.device(input.device)
+    ):
         output = torch.empty(
             (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
         )
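The second change generalizes the device guard around the Triton launch so the kernel also runs on XPU. A sketch of the same pattern as a reusable helper, assuming a PyTorch build recent enough to expose the torch.xpu namespace; the helper name device_ctx is hypothetical and not part of the commit:

    import torch

    def device_ctx(device: torch.device):
        # torch.xpu mirrors the torch.cuda API in recent PyTorch builds;
        # pick whichever backend is actually available. Hypothetical helper,
        # not part of the commit.
        if torch.xpu.is_available():
            return torch.xpu.device(device)
        return torch.cuda.device(device)

    # Usage: make the kernel launch target the input tensor's device,
    # whether that device is an XPU or a CUDA GPU.
    # with device_ctx(inp.device):
    #     out = torch.empty((inp.shape[0], n), device=inp.device, dtype=torch.float16)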