diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 2caf63c0..ec04a240 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -242,7 +242,7 @@ def download_weights(
             if not extension == ".safetensors" or not auto_convert:
                 raise e
 
-    else:
+    elif (Path(model_id) / "adapter_config.json").exists():
         # Try to load as a local PEFT model
         try:
             utils.download_and_unload_peft(
diff --git a/server/text_generation_server/utils/peft.py b/server/text_generation_server/utils/peft.py
index 45e23320..48ca264b 100644
--- a/server/text_generation_server/utils/peft.py
+++ b/server/text_generation_server/utils/peft.py
@@ -10,8 +10,7 @@ from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
 def download_and_unload_peft(model_id, revision, trust_remote_code):
     torch_dtype = torch.float16
 
-    logger.info("Peft model detected.")
-    logger.info("Loading the model it might take a while without feedback")
+    logger.info("Trying to load a Peft model. It might take a while without feedback")
     try:
         model = AutoPeftModelForCausalLM.from_pretrained(
             model_id,
@@ -28,7 +27,7 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
             trust_remote_code=trust_remote_code,
             low_cpu_mem_usage=True,
         )
-    logger.info(f"Loaded.")
+    logger.info("Peft model detected.")
     logger.info(f"Merging the lora weights.")
     base_model_id = model.peft_config["default"].base_model_name_or_path
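
The behavioral change in cli.py: a local model directory is only routed through the PEFT fallback when it actually looks like an adapter checkpoint, instead of on every local path. Below is a minimal standalone sketch of that check, not part of the patch; the path and the helper name are illustrative, and the only assumption is that peft's save_pretrained writes adapter_config.json next to the adapter weights.

    from pathlib import Path

    def is_local_peft_model(model_id: str) -> bool:
        # peft's save_pretrained emits adapter_config.json alongside the
        # adapter weights, so its presence is a cheap marker for an
        # adapter checkpoint rather than full model weights.
        return (Path(model_id) / "adapter_config.json").exists()

    # Hypothetical local path, for illustration only.
    if is_local_peft_model("/data/my-lora-checkpoint"):
        print("adapter detected; attempt the PEFT merge path")
    else:
        print("plain weights; skip the PEFT fallback")

The log reordering in peft.py follows the same motivation: the old code announced "Peft model detected." before attempting the load, so a local directory of plain weights produced a misleading message; the patch logs a tentative "Trying to load a Peft model" first and only confirms detection after the load succeeds.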