diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 497c8f50..0bb0d1a4 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -154,17 +154,6 @@ def download_weights( try: import json - # TODO remove this, there's currently logic to avoid detecting this file - # as being enough to say the entire repo is on disk, since we also need the parent model - # We're keeping this potential download to prevent breaking backward compatibilty, - # but it shouldn't be necessary in the flow here. - try: - medusa_head = hf_hub_download( - model_id, revision=revision, filename="medusa_lm_head.safetensors" - ) - except Exception: - pass - config = hf_hub_download( model_id, revision=revision, filename="config.json" ) @@ -176,17 +165,17 @@ def download_weights( revision = "main" try: utils.weight_files(base_model_id, revision, extension) - logger.info( - f"Files for parent {base_model_id} are already present on the host. " - "Skipping download." - ) - return # Local files not found except ( utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError, ): + logger.info(f"Downloading parent model {base_model_id}") + filenames = utils.weight_hub_files( + base_model_id, revision, extension + ) + utils.download_weights(filenames, base_model_id, revision) pass except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): pass @@ -235,7 +224,14 @@ def download_weights( return # Local files not found except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): - pass + try: + logger.info(f"Downloading parent model {base_model_id}") + filenames = utils.weight_hub_files( + base_model_id, revision, extension + ) + utils.download_weights(filenames, base_model_id, revision) + except Exception: + pass except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): pass diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index a81e659d..b56484f6 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -40,7 +40,6 @@ def _weight_hub_files_from_model_info( and "arguments" not in s.rfilename and "args" not in s.rfilename and "training" not in s.rfilename - and "medusa_lm_head" not in s.rfilename ] @@ -57,7 +56,6 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]: and "args" not in f and "adapter" not in f and "training" not in f - and "medusa_lm_head" not in f ] return filenames