Fixing GPTQ exllama kernel usage.

commit 2d4ae09074
parent 6df43da0a4
Author: Nicolas Patry
Date:   2023-10-04 15:50:56 +00:00


@@ -212,7 +212,9 @@ class Weights:
                 g_idx = None
 
             bits, groupsize = self._get_gptq_params()
-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+            from text_generation_server.utils.layers import HAS_EXLLAMA
+            use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq"
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)
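
For context, a minimal sketch of the gating pattern this commit introduces: the exllama kernels only handle 4-bit GPTQ weights, so the last element of the weight tuple is now derived from the bit width, kernel availability, and quantization mode instead of being hard-coded to False. HAS_EXLLAMA and the tuple layout come from the diff above; the pack_gptq_weight helper and the constant stand-in for HAS_EXLLAMA are hypothetical, for illustration only.

from typing import Optional, Tuple

HAS_EXLLAMA = True  # stand-in; the real flag reflects whether the exllama kernels imported


def use_exllama_kernel(bits: int, quantize: Optional[str]) -> bool:
    # The exllama matmul kernels only support 4-bit GPTQ, so any other
    # bit width or quantization mode must fall back to the generic path.
    return bits == 4 and HAS_EXLLAMA and quantize == "gptq"


def pack_gptq_weight(qweight, qzeros, scales, g_idx, bits, groupsize, quantize) -> Tuple:
    # Hypothetical helper mirroring the tuple built in the diff above; the
    # final element tells the downstream linear layer which kernel to pick.
    return (qweight, qzeros, scales, g_idx, bits, groupsize,
            use_exllama_kernel(bits, quantize))


# e.g. an 8-bit GPTQ checkpoint falls back to the generic kernel:
assert use_exllama_kernel(8, "gptq") is False
assert use_exllama_kernel(4, "gptq") is True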