diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 0bb0d1a4..ad623ccc 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -161,21 +161,19 @@ def download_weights( config = json.load(f) base_model_id = config.get("base_model_name_or_path", None) - if base_model_id: - revision = "main" + if base_model_id and base_model_id != model_id: try: - utils.weight_files(base_model_id, revision, extension) - # Local files not found - except ( - utils.LocalEntryNotFoundError, - FileNotFoundError, - utils.EntryNotFoundError, - ): logger.info(f"Downloading parent model {base_model_id}") - filenames = utils.weight_hub_files( - base_model_id, revision, extension + download_weights( + model_id=base_model_id, + revision="main", + extension=extension, + auto_convert=auto_convert, + logger_level=logger_level, + json_output=json_output, + trust_remote_code=trust_remote_code, ) - utils.download_weights(filenames, base_model_id, revision) + except Exception: pass except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): pass @@ -214,34 +212,32 @@ def download_weights( base_model_id = config.get("base_model_name_or_path", None) if base_model_id: - revision = "main" try: - utils.weight_files(base_model_id, revision, extension) - logger.info( - f"Files for parent {base_model_id} are already present on the host. " - "Skipping download." + logger.info(f"Downloading parent model {base_model_id}") + download_weights( + model_id=base_model_id, + revision="main", + extension=extension, + auto_convert=auto_convert, + logger_level=logger_level, + json_output=json_output, + trust_remote_code=trust_remote_code, ) - return - # Local files not found - except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): - try: - logger.info(f"Downloading parent model {base_model_id}") - filenames = utils.weight_hub_files( - base_model_id, revision, extension - ) - utils.download_weights(filenames, base_model_id, revision) - except Exception: - pass + except Exception: + pass except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): pass # Try to see if there are local pytorch weights try: # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE - local_pt_files = utils.weight_files(model_id, revision, ".bin") + try: + local_pt_files = utils.weight_files(model_id, revision, ".bin") + except Exception: + local_pt_files = utils.weight_files(model_id, revision, ".pt") # No local pytorch weights - except utils.LocalEntryNotFoundError: + except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): if extension == ".safetensors": logger.warning( f"No safetensors weights found for model {model_id} at revision {revision}. "