diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index e603e7fc..29505902 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -63,7 +63,7 @@ class FlashSantacoderSharded(FlashCausalLM):
             aliases={"transformer.wte.weight": ["lm_head.weight"]},
         )
         if config.quantize == "gptq":
-            weights.set_gptq_params(model_id)
+            weights._set_gptq_params(model_id)
 
         model = FlashSantacoderForCausalLM(config, weights)
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 2e6b619b..3bfbf22c 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -3,6 +3,8 @@ from typing import List, Dict, Optional, Tuple
 from safetensors import safe_open, SafetensorError
 import torch
 from loguru import logger
+from huggingface_hub import hf_hub_download
+import json
 
 
 class Weights: