diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index c3876023..fed5e6f3 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -81,7 +81,7 @@ class BLOOMSharded(CausalLM):
             prefix="transformer",
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = BloomForCausalLM(config, weights)
 
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 2415a245..8a3bccdd 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -64,7 +64,7 @@ class FlashLlama(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
         if config.quantize in ["gptq", "awq"]:
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = FlashLlamaForCausalLM(config, weights)
         if use_medusa:
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index abe07c30..8c6cb025 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -328,7 +328,7 @@ class BaseFlashMistral(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
         if config.quantize in ["gptq", "awq"]:
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = model_cls(config, weights)
 
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 58f345a9..80f8804d 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -53,7 +53,7 @@ class FlashNeoXSharded(FlashCausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = FlashGPTNeoXForCausalLM(config, weights)
 
diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index 195b3883..dfab8888 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -62,7 +62,7 @@ class FlashRWSharded(FlashCausalLM):
 
         config.quantize = quantize
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = FlashRWForCausalLM(config, weights)
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 29505902..22171ec0 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -63,7 +63,7 @@ class FlashSantacoderSharded(FlashCausalLM):
             aliases={"transformer.wte.weight": ["lm_head.weight"]},
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = FlashSantacoderForCausalLM(config, weights)
 
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index b296c96e..42ff1c80 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -199,7 +199,7 @@ class GalacticaSharded(CausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = OPTForCausalLM(config, weights)
 
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index d4c64dfe..45df4839 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -57,7 +57,7 @@ class GPTNeoxSharded(CausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = GPTNeoxForCausalLM(config, weights)
 
diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py
index 19de497c..e419467f 100644
--- a/server/text_generation_server/models/mpt.py
+++ b/server/text_generation_server/models/mpt.py
@@ -81,7 +81,7 @@ class MPTSharded(CausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         config.quantize = quantize
         model = MPTForCausalLM(config, weights)
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index b2b87246..58fb212f 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -55,7 +55,7 @@ class OPTSharded(CausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = OPTForCausalLM(config, weights)
 
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 802c1a90..67fda511 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -327,13 +327,15 @@ class Weights:
 
         return bits, groupsize
 
-    def _set_gptq_params(self, model_id):
+    def _set_gptq_params(self, model_id, revision):
         filename = "config.json"
         try:
             if os.path.exists(os.path.join(model_id, filename)):
                 filename = os.path.join(model_id, filename)
             else:
-                filename = hf_hub_download(model_id, filename=filename)
+                filename = hf_hub_download(
+                    model_id, filename=filename, revision=revision
+                )
             with open(filename, "r") as f:
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
@@ -344,7 +346,9 @@ class Weights:
                 if os.path.exists(os.path.join(model_id, filename)):
                     filename = os.path.join(model_id, filename)
                 else:
-                    filename = hf_hub_download(model_id, filename=filename)
+                    filename = hf_hub_download(
+                        model_id, filename=filename, revision=revision
+                    )
                 with open(filename, "r") as f:
                     data = json.load(f)
                 self.gptq_bits = data["bits"]
@@ -355,7 +359,9 @@ class Weights:
                     if os.path.exists(os.path.join(model_id, filename)):
                         filename = os.path.join(model_id, filename)
                     else:
-                        filename = hf_hub_download(model_id, filename=filename)
+                        filename = hf_hub_download(
+                            model_id, filename=filename, revision=revision
+                        )
                     with open(filename, "r") as f:
                         data = json.load(f)
                     self.gptq_bits = data["w_bit"]
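
The behavioral change sits entirely in `_set_gptq_params`: when the quantization config is not present on disk, it is now downloaded from the requested `revision` instead of the Hub's default branch. The sketch below is not part of the patch; it restates only the `config.json` branch shown in the first hunk as a standalone helper (the name `gptq_bits_for` is made up for illustration) to make the role of `revision` explicit.

import json
import os

from huggingface_hub import hf_hub_download


def gptq_bits_for(model_id, revision=None):
    # Mirrors the first lookup in Weights._set_gptq_params: prefer a local
    # config.json next to the weights, otherwise fetch it from the Hub.
    filename = "config.json"
    local_path = os.path.join(model_id, filename)
    if os.path.exists(local_path):
        path = local_path
    else:
        # Without revision=..., hf_hub_download resolves against the default
        # branch ("main"), which is the bug this patch fixes for
        # revision-pinned quantized models.
        path = hf_hub_download(model_id, filename=filename, revision=revision)
    with open(path, "r") as f:
        data = json.load(f)
    return data["quantization_config"]["bits"]

For example, `gptq_bits_for("TheBloke/Llama-2-7B-GPTQ", revision="gptq-4bit-32g-actorder_True")` (illustrative model id and branch) would read the bit width from that branch rather than from `main`, matching how the model classes above now forward `revision` alongside `model_id`.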