diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 459ba8c4..b12a9751 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -171,14 +171,14 @@ def download_weights( for p in local_pt_files ] try: - from transformers import AutoConfig import transformers + import json - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - ) - architecture = config.architectures[0] + + config_filename = hf_hub_download(model_id, revision=revision, filename="config.json") + with open(config_filename, "r") as f: + config = json.load(f) + architecture = config["architectures"][0] class_ = getattr(transformers, architecture)