mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00

commit 646ab28285
parent 8074c40473

    typing
@@ -154,17 +154,18 @@ class Weights:
                 f"Cannot load `{quantize}` weight, make sure the model is already quantized."
             )
 
+            bits, groupsize, _, quant_method = self._get_gptq_params()
+
             qzeros = self._get_qweight(f"{prefix}.qzeros")
             scales = self._get_qweight(f"{prefix}.scales")
             scales = scales.to(dtype=self.dtype)
 
-            if quantize == "gptq" and self.quant_method == "gptq":
+            if quantize == "gptq" and quant_method == "gptq":
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
             else:
                 g_idx = None
 
-            bits, groupsize, _, _ = self._get_gptq_params()
-            if quantize == "gptq" and self.quant_method == "awq":
+            if quantize == "gptq" and quant_method == "awq":
                 log_once(
                     logger.info,
                     "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
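Note on the first hunk: the quantization method is now taken from the checkpoint metadata returned by `_get_gptq_params()` instead of from the loader attribute `self.quant_method`, and the call is hoisted above the branches that consume it. A hedged, self-contained sketch of the resulting dispatch pattern (the function name, dict-based input, and defaults below are illustrative stand-ins, not the repository's API):

    from typing import Dict, Tuple

    # Illustrative stand-in for Weights._get_gptq_params(); the real method
    # reads scalar tensors from the checkpoint, this toy reads a plain dict.
    def get_gptq_params(metadata: Dict[str, object]) -> Tuple[int, int, int, str]:
        bits = int(metadata.get("gptq_bits", 4))                   # assumed default
        groupsize = int(metadata.get("gptq_groupsize", 128))       # assumed default
        desc_act = int(metadata.get("gptq_desc_act", 0))           # assumed default
        quant_method = str(metadata.get("quant_method", "gptq"))   # "gptq" or "awq"
        return bits, groupsize, desc_act, quant_method

    # The call-site pattern the hunk switches to: unpack once, then branch
    # on the returned quant_method rather than on a loader attribute.
    bits, groupsize, _, quant_method = get_gptq_params({"quant_method": "awq"})
    if quant_method == "awq":
        print("AWQ checkpoint: would repack weights for the Exllama/GPTQ kernels")
    else:
        print("GPTQ checkpoint: would load the g_idx tensor")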
@@ -212,7 +213,9 @@ class Weights:
                 [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
             )
 
-            if quantize == "gptq" and self.quant_method == "gptq":
+            bits, groupsize, desc_act, quant_method = self._get_gptq_params()
+
+            if quantize == "gptq" and quant_method == "gptq":
                 w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                 for w2 in w[1:]:
                     torch.testing.assert_close(w2, w[0])
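The second hunk moves the parameter fetch above the g_idx consistency check so that `quant_method` is in scope for the rewritten condition (the third hunk below then drops the now-redundant fetch further down). For context, the check asserts that every fused prefix shares one `g_idx`; a runnable toy version:

    import torch

    # When several column-parallel prefixes (e.g. q/k/v projections) are fused
    # into one matrix, their GPTQ g_idx tensors must be identical, or the fused
    # kernel would apply the wrong quantization group to some columns.
    w = [torch.tensor([0, 0, 1, 1, 2, 2]) for _ in range(3)]  # toy g_idx shards
    for w2 in w[1:]:
        torch.testing.assert_close(w2, w[0])  # raises if any shard disagrees
    print("all g_idx tensors match")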
@@ -220,7 +223,6 @@ class Weights:
             else:
                 g_idx = None
 
-            bits, groupsize, desc_act, quant_method = self._get_gptq_params()
             from text_generation_server.utils.layers import HAS_EXLLAMA
 
             use_exllama = (
@@ -347,7 +349,7 @@ class Weights:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
 
-    def _get_gptq_params(self) -> Tuple[int, int, int]:
+    def _get_gptq_params(self) -> Tuple[int, int, int, str]:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
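The last hunk widens the return annotation: the call sites above already unpack four values, so `Tuple[int, int, int]` was wrong. The hunk cuts off after the two `get_tensor` reads, so here is a hedged standalone sketch of a body consistent with the new `Tuple[int, int, int, str]` signature; the defaults and the fallback branch are assumptions, not the commit's code:

    import torch
    from typing import Dict, Tuple

    def get_gptq_params(tensors: Dict[str, torch.Tensor]) -> Tuple[int, int, int, str]:
        try:
            # Mirrors the visible lines of the hunk: bits/groupsize are stored
            # as scalar tensors in the quantized checkpoint.
            bits = int(tensors["gptq_bits"].item())
            groupsize = int(tensors["gptq_groupsize"].item())
            desc_act = 0           # assumed: no act-order metadata tensor present
            quant_method = "gptq"  # assumed default for bare GPTQ checkpoints
        except KeyError:
            # Assumed fallback; the real loader would consult the model's
            # quantization config here rather than hard-coded values.
            bits, groupsize, desc_act, quant_method = 4, 128, 0, "gptq"
        return bits, groupsize, desc_act, quant_method

    print(get_gptq_params({"gptq_bits": torch.tensor(4), "gptq_groupsize": torch.tensor(128)}))
    # -> (4, 128, 0, 'gptq')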