From f871f114ca5f5a18a2a4a2c7658aed87440d381f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Sat, 18 May 2024 13:31:24 +0200 Subject: [PATCH] Fixing the download strategy for ibm-fms (#1917) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- server/text_generation_server/cli.py | 96 +++++++++++----------- server/text_generation_server/utils/hub.py | 2 - 2 files changed, 48 insertions(+), 50 deletions(-) diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index bb0963d4..ad623ccc 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -154,31 +154,27 @@ def download_weights( try: import json - medusa_head = hf_hub_download( - model_id, revision=revision, filename="medusa_lm_head.safetensors" - ) - medusa_config = hf_hub_download( + config = hf_hub_download( model_id, revision=revision, filename="config.json" ) - with open(medusa_config, "r") as f: + with open(config, "r") as f: config = json.load(f) - model_id = config["base_model_name_or_path"] - revision = "main" - try: - utils.weight_files(model_id, revision, extension) - logger.info( - f"Files for parent {model_id} are already present on the host. " - "Skipping download." - ) - return - # Local files not found - except ( - utils.LocalEntryNotFoundError, - FileNotFoundError, - utils.EntryNotFoundError, - ): - pass + base_model_id = config.get("base_model_name_or_path", None) + if base_model_id and base_model_id != model_id: + try: + logger.info(f"Downloading parent model {base_model_id}") + download_weights( + model_id=base_model_id, + revision="main", + extension=extension, + auto_convert=auto_convert, + logger_level=logger_level, + json_output=json_output, + trust_remote_code=trust_remote_code, + ) + except Exception: + pass except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): pass @@ -195,31 +191,6 @@ def download_weights( if not extension == ".safetensors" or not auto_convert: raise e - elif (Path(model_id) / "medusa_lm_head.safetensors").exists(): - # Try to load as a local Medusa model - try: - import json - - medusa_head = Path(model_id) / "medusa_lm_head.safetensors" - medusa_config = Path(model_id) / "config.json" - with open(medusa_config, "r") as f: - config = json.load(f) - - model_id = config["base_model_name_or_path"] - revision = "main" - try: - utils.weight_files(model_id, revision, extension) - logger.info( - f"Files for parent {model_id} are already present on the host. " - "Skipping download." - ) - return - # Local files not found - except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): - pass - except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): - pass - elif (Path(model_id) / "adapter_config.json").exists(): # Try to load as a local PEFT model try: @@ -230,14 +201,43 @@ def download_weights( return except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): pass + elif (Path(model_id) / "config.json").exists(): + # Try to load as a local Medusa model + try: + import json + + config = Path(model_id) / "config.json" + with open(config, "r") as f: + config = json.load(f) + + base_model_id = config.get("base_model_name_or_path", None) + if base_model_id: + try: + logger.info(f"Downloading parent model {base_model_id}") + download_weights( + model_id=base_model_id, + revision="main", + extension=extension, + auto_convert=auto_convert, + logger_level=logger_level, + json_output=json_output, + trust_remote_code=trust_remote_code, + ) + except Exception: + pass + except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): + pass # Try to see if there are local pytorch weights try: # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE - local_pt_files = utils.weight_files(model_id, revision, ".bin") + try: + local_pt_files = utils.weight_files(model_id, revision, ".bin") + except Exception: + local_pt_files = utils.weight_files(model_id, revision, ".pt") # No local pytorch weights - except utils.LocalEntryNotFoundError: + except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): if extension == ".safetensors": logger.warning( f"No safetensors weights found for model {model_id} at revision {revision}. " diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index a81e659d..b56484f6 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -40,7 +40,6 @@ def _weight_hub_files_from_model_info( and "arguments" not in s.rfilename and "args" not in s.rfilename and "training" not in s.rfilename - and "medusa_lm_head" not in s.rfilename ] @@ -57,7 +56,6 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]: and "args" not in f and "adapter" not in f and "training" not in f - and "medusa_lm_head" not in f ] return filenames