Bug fixes for GPTQ_BITS env var passthrough

This commit is contained in:
ssmi153 2023-07-12 17:25:24 +08:00
parent 7f9072228a
commit 636a4cca85

View File

@ -127,8 +127,8 @@ class Weights:
try:
import os
bits = int(os.getenv("GTPQ_BITS"))
groupsize = int(os.getenv("GTPQ_GROUPSIZE"))
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception:
raise e
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
@ -149,8 +149,17 @@ class Weights:
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
except SafetensorError as e:
try:
import os
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception:
raise e
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
else: