diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index bb0963d4..ba45916c 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -154,13 +154,21 @@ def download_weights(
         try:
             import json
 
-            medusa_head = hf_hub_download(
-                model_id, revision=revision, filename="medusa_lm_head.safetensors"
-            )
-            medusa_config = hf_hub_download(
+            # TODO remove this, there's currently logic to avoid detecting this file
+            # as being enough to say the entire repo is on disk, since we also need the parent model
+            # We're keeping this potential download to prevent breaking backward compatibility,
+            # but it shouldn't be necessary in the flow here.
+            try:
+                medusa_head = hf_hub_download(
+                    model_id, revision=revision, filename="medusa_lm_head.safetensors"
+                )
+            except Exception:
+                pass
+
+            config = hf_hub_download(
                 model_id, revision=revision, filename="config.json"
             )
-            with open(medusa_config, "r") as f:
+            with open(config, "r") as f:
                 config = json.load(f)
 
             model_id = config["base_model_name_or_path"]
@@ -195,14 +203,23 @@ def download_weights(
             if not extension == ".safetensors" or not auto_convert:
                 raise e
 
-    elif (Path(model_id) / "medusa_lm_head.safetensors").exists():
+    elif (Path(model_id) / "adapter_config.json").exists():
+        # Try to load as a local PEFT model
+        try:
+            utils.download_and_unload_peft(
+                model_id, revision, trust_remote_code=trust_remote_code
+            )
+            utils.weight_files(model_id, revision, extension)
+            return
+        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+            pass
+    elif (Path(model_id) / "config.json").exists():
         # Try to load as a local Medusa model
         try:
             import json
 
-            medusa_head = Path(model_id) / "medusa_lm_head.safetensors"
-            medusa_config = Path(model_id) / "config.json"
-            with open(medusa_config, "r") as f:
+            config = Path(model_id) / "config.json"
+            with open(config, "r") as f:
                 config = json.load(f)
 
             model_id = config["base_model_name_or_path"]
@@ -220,17 +237,6 @@ def download_weights(
         except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
             pass
 
-    elif (Path(model_id) / "adapter_config.json").exists():
-        # Try to load as a local PEFT model
-        try:
-            utils.download_and_unload_peft(
-                model_id, revision, trust_remote_code=trust_remote_code
-            )
-            utils.weight_files(model_id, revision, extension)
-            return
-        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
-            pass
-
     # Try to see if there are local pytorch weights
     try:
         # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
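
For reference, the branch reordering above gives a local model directory the following detection precedence: a PEFT adapter (adapter_config.json) is handled first, then a Medusa-style config.json that names a parent model via base_model_name_or_path, and only then the regular local-weights path. The sketch below is a minimal illustration of that precedence, not code from the repository; classify_local_model is a hypothetical helper introduced purely for this example.

# Minimal sketch of the detection order introduced above (illustration only;
# classify_local_model is a hypothetical helper, not part of cli.py).
import json
from pathlib import Path


def classify_local_model(model_dir: str) -> str:
    path = Path(model_dir)

    if (path / "adapter_config.json").exists():
        # PEFT adapter: the adapter weights get merged into the base model.
        return "peft"

    if (path / "config.json").exists():
        with open(path / "config.json", "r") as f:
            config = json.load(f)
        if "base_model_name_or_path" in config:
            # Medusa-style head: the parent model still has to be fetched.
            return "medusa"

    # Otherwise fall through to the regular local-weights path.
    return "weights"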