diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 39f66862c..4f300fe76 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -127,8 +127,8 @@ class Weights: try: import os - bits = int(os.getenv("GTPQ_BITS")) - groupsize = int(os.getenv("GTPQ_GROUPSIZE")) + bits = int(os.getenv("GPTQ_BITS")) + groupsize = int(os.getenv("GPTQ_GROUPSIZE")) except Exception: raise e weight = (qweight, qzeros, scales, g_idx, bits, groupsize) @@ -149,8 +149,17 @@ class Weights: scales = self.get_tensor(f"{prefix}.scales") g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - bits = self.get_tensor("gptq_bits").item() - groupsize = self.get_tensor("gptq_groupsize").item() + try: + bits = self.get_tensor("gptq_bits").item() + groupsize = self.get_tensor("gptq_groupsize").item() + except SafetensorError as e: + try: + import os + + bits = int(os.getenv("GPTQ_BITS")) + groupsize = int(os.getenv("GPTQ_GROUPSIZE")) + except Exception: + raise e weight = (qweight, qzeros, scales, g_idx, bits, groupsize) else: