diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 301acb6b..887b5732 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -149,6 +149,17 @@ def download_weights(
             # Check if we want to automatically convert to safetensors or if we can use .bin weights instead
             if not extension == ".safetensors" or not auto_convert:
                 raise e
+
+        # Fall back to treating model_id as a PEFT adapter: merge the adapter
+        # into its base model, save the merged weights, and serve those.
+        try:
+            utils.download_and_unload_peft(
+                model_id, revision, trust_remote_code=trust_remote_code
+            )
+            utils.weight_files(model_id, revision, extension)
+            return
+        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+            pass
 
     # Try to see if there are local pytorch weights
     try: