From da27fbdfdbaaab4157e0aa6114c16503471c499c Mon Sep 17 00:00:00 2001
From: PYNing <540439329@qq.com>
Date: Thu, 11 Jan 2024 01:36:20 +0800
Subject: [PATCH] Fix local load for Medusa (#1420)

# What does this PR do?

Close #1418
Close #1415

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
---
 server/text_generation_server/cli.py      | 29 +++++++++++++++++++
 .../models/flash_llama.py                 | 25 +++++++++++-----
 2 files changed, 47 insertions(+), 7 deletions(-)

diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 403f46e7..99be6c7e 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -198,6 +198,35 @@ def download_weights(
             if not extension == ".safetensors" or not auto_convert:
                 raise e
 
+    elif (Path(model_id) / "medusa_lm_head.pt").exists():
+        # Try to load as a local Medusa model
+        try:
+            import json
+
+            medusa_head = Path(model_id) / "medusa_lm_head.pt"
+            if auto_convert:
+                medusa_sf = Path(model_id) / "medusa_lm_head.safetensors"
+                if not medusa_sf.exists():
+                    utils.convert_files([Path(medusa_head)], [medusa_sf], [])
+            medusa_config = Path(model_id) / "config.json"
+            with open(medusa_config, "r") as f:
+                config = json.load(f)
+
+            model_id = config["base_model_name_or_path"]
+            revision = "main"
+            try:
+                utils.weight_files(model_id, revision, extension)
+                logger.info(
+                    f"Files for parent {model_id} are already present on the host. "
+                    "Skipping download."
+                )
+                return
+            # Local files not found
+            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+                pass
+        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+            pass
+
     elif (Path(model_id) / "adapter_config.json").exists():
         # Try to load as a local PEFT model
         try:
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 8a3bccdd..7be61906 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -71,15 +71,26 @@ class FlashLlama(FlashCausalLM):
             from text_generation_server.utils.medusa import MedusaModel
             from huggingface_hub import hf_hub_download
             import json
-
-            medusa_config = hf_hub_download(
-                use_medusa, revision=revision, filename="config.json"
-            )
+            import os
+            from pathlib import Path
+
+            is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
+                "WEIGHTS_CACHE_OVERRIDE", None
+            ) is not None
+
+            if not is_local_model:
+                medusa_config = hf_hub_download(
+                    use_medusa, revision=revision, filename="config.json"
+                )
+                medusa_head = hf_hub_download(
+                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
+                )
+            else:
+                medusa_config = str(Path(use_medusa) / "config.json")
+                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
+
             with open(medusa_config, "r") as f:
                 config = json.load(f)
-            medusa_head = hf_hub_download(
-                use_medusa, revision=revision, filename="medusa_lm_head.pt"
-            )
             medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
             weights = Weights(
                 [medusa_sf], device, dtype, process_group=self.process_group
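For context, the resolution the new `cli.py` branch performs when `model_id` is a local directory is: look for `medusa_lm_head.pt` next to a `config.json`, convert the head to safetensors if `auto_convert` is set, and fall through to the parent model named by `base_model_name_or_path`. The sketch below restates that flow outside the server for illustration only; the helper name `resolve_local_medusa` is an assumption and is not part of this PR.

```python
import json
from pathlib import Path


def resolve_local_medusa(path: str):
    """Illustrative sketch of the local Medusa resolution added above.

    Expects a directory containing `medusa_lm_head.pt` and a `config.json`
    whose `base_model_name_or_path` names the parent model to serve.
    """
    medusa_dir = Path(path)
    medusa_head = medusa_dir / "medusa_lm_head.pt"
    if not medusa_head.exists():
        raise FileNotFoundError(f"{medusa_head} not found; not a local Medusa checkpoint")

    with open(medusa_dir / "config.json", "r") as f:
        config = json.load(f)

    # The parent model's weights are downloaded/loaded separately by the server.
    return medusa_head, config["base_model_name_or_path"]
```

Given a local Medusa directory, this returns the head path and the base model id that `download-weights` then checks for or fetches, mirroring the two code paths patched above.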