commit 646ab28285
parent 8074c40473

    typing
server/text_generation_server/utils/weights.py
@@ -154,17 +154,18 @@ class Weights:
                     f"Cannot load `{quantize}` weight, make sure the model is already quantized."
                 )
 
+            bits, groupsize, _, quant_method = self._get_gptq_params()
+
             qzeros = self._get_qweight(f"{prefix}.qzeros")
             scales = self._get_qweight(f"{prefix}.scales")
             scales = scales.to(dtype=self.dtype)
-            if quantize == "gptq" and self.quant_method == "gptq":
+
+            if quantize == "gptq" and quant_method == "gptq":
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
             else:
                 g_idx = None
 
-            bits, groupsize, _, _ = self._get_gptq_params()
-
-            if quantize == "gptq" and self.quant_method == "awq":
+            if quantize == "gptq" and quant_method == "awq":
                 log_once(
                     logger.info,
                     "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
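After this hunk, the GPTQ/AWQ branch depends on the `quant_method` value returned by `_get_gptq_params()` instead of `self.quant_method`. Below is a minimal, self-contained sketch of that branch; the dict-based `state_dict` and the function name are hypothetical stand-ins, not the actual text-generation-inference API. The point it mirrors: AWQ checkpoints reuse the GPTQ packed-tensor layout but carry no `g_idx`, so the activation-order index is fetched only for true GPTQ weights.

```python
from typing import Optional, Tuple

import torch


def get_col_packed_qparams(
    state_dict: dict,  # hypothetical: maps tensor names to tensors
    prefix: str,
    quantize: str,
    quant_method: str,
    dtype: torch.dtype = torch.float16,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    qweight = state_dict[f"{prefix}.qweight"]
    qzeros = state_dict[f"{prefix}.qzeros"]
    scales = state_dict[f"{prefix}.scales"].to(dtype=dtype)
    # Only checkpoints quantized by GPTQ itself store an activation-order
    # index; AWQ checkpoints share the packed layout but have no g_idx.
    if quantize == "gptq" and quant_method == "gptq":
        g_idx = state_dict[f"{prefix}.g_idx"]
    else:
        g_idx = None
    return qweight, qzeros, scales, g_idx
```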
@@ -212,7 +213,9 @@ class Weights:
                 [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
             )
 
-            if quantize == "gptq" and self.quant_method == "gptq":
+            bits, groupsize, desc_act, quant_method = self._get_gptq_params()
+
+            if quantize == "gptq" and quant_method == "gptq":
                 w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                 for w2 in w[1:]:
                     torch.testing.assert_close(w2, w[0])
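The `assert_close` loop in this hunk guards the fused load: when several column-parallel projections are packed into one matmul, every prefix must share the same `g_idx`, so each tensor is compared against the first. A small synthetic example of the same check (the tensors here are made up; the real ones come from `self.get_tensor`):

```python
import torch

# Three hypothetical projections quantized with the same group size,
# hence identical activation-order indices.
g_idx_per_prefix = [torch.arange(8) // 2 for _ in ("q_proj", "k_proj", "v_proj")]

for other in g_idx_per_prefix[1:]:
    # Raises AssertionError if any projection's g_idx diverges from the first.
    torch.testing.assert_close(other, g_idx_per_prefix[0])
g_idx = g_idx_per_prefix[0]
```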
@@ -220,7 +223,6 @@ class Weights:
             else:
                 g_idx = None
 
-            bits, groupsize, desc_act, quant_method = self._get_gptq_params()
             from text_generation_server.utils.layers import HAS_EXLLAMA
 
             use_exllama = (
@@ -347,7 +349,7 @@ class Weights:
         weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
 
-    def _get_gptq_params(self) -> Tuple[int, int, int]:
+    def _get_gptq_params(self) -> Tuple[int, int, int, str]:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
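The final hunk only widens the return annotation to the four values the function already yields. For reference, a hedged sketch of a `_get_gptq_params` shaped like the new signature: only the two `get_tensor` reads inside `try` come from the diff, while the `except` fallback and the default values for `desc_act` and `quant_method` are assumptions for illustration.

```python
from typing import Tuple

import torch


class WeightsSketch:
    """Hypothetical stand-in for the Weights class, backed by a plain dict."""

    def __init__(self, tensors: dict):
        self._tensors = tensors

    def get_tensor(self, name: str) -> torch.Tensor:
        return self._tensors[name]

    def _get_gptq_params(self) -> Tuple[int, int, int, str]:
        try:
            # These two reads mirror the diff above.
            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()
        except KeyError:
            # Assumed fallback: values that would normally come from the
            # model config when the checkpoint stores no gptq_* tensors.
            bits, groupsize = 4, 128
        # desc_act and quant_method defaults are assumptions, not from the diff.
        return bits, groupsize, False, "gptq"


# Usage: a checkpoint that does carry the gptq_* scalar tensors.
w = WeightsSketch({"gptq_bits": torch.tensor(4), "gptq_groupsize": torch.tensor(128)})
print(w._get_gptq_params())  # (4, 128, False, 'gptq')
```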