diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 0330402d..dcab6296 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,3 +1,4 @@ +import os from pathlib import Path from typing import List, Dict, Optional, Tuple from safetensors import safe_open, SafetensorError @@ -221,8 +222,12 @@ class Weights: return bits, groupsize def _set_gptq_params(self, model_id): + filename = "quantize_config.json" try: - filename = hf_hub_download(model_id, filename="quantize_config.json") + if not os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download(model_id, filename=filename) with open(filename, "r") as f: data = json.load(f) self.gptq_bits = data["bits"]