def get_gptq_qparams(self) -> Tuple[int, int]:
    """Return the ``(bits, groupsize)`` pair used for GPTQ quantization.

    The values are read from the scalar tensors named ``"gptq_bits"`` and
    ``"gptq_groupsize"`` via ``self.get_tensor``. If that lookup fails with
    ``SafetensorError`` or ``RuntimeError``, both values are taken from the
    ``GPTQ_BITS`` and ``GPTQ_GROUPSIZE`` environment variables instead.

    Returns:
        A ``(bits, groupsize)`` tuple.

    Raises:
        SafetensorError, RuntimeError: the original tensor-lookup error,
            re-raised when the environment fallback is missing or is not
            a valid integer.
    """
    try:
        bits = self.get_tensor("gptq_bits").item()
        groupsize = self.get_tensor("gptq_groupsize").item()
    except (SafetensorError, RuntimeError) as e:
        # Imported locally so the block does not depend on a file-level
        # ``import os`` that may or may not exist.
        import os

        env_bits = os.getenv("GPTQ_BITS")
        env_groupsize = os.getenv("GPTQ_GROUPSIZE")
        if env_bits is None or env_groupsize is None:
            # No override configured: re-raise the lookup error as-is,
            # without chaining an unrelated ``int(None)`` TypeError onto it.
            raise

        try:
            bits = int(env_bits)
            groupsize = int(env_groupsize)
        except ValueError:
            # Malformed override: callers still see the original lookup
            # error; the ValueError remains visible as its context.
            raise e

    return bits, groupsize