From 529d7c259128f025f34c3493ba8353f43db9762c Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 21 Dec 2023 17:29:23 +0100 Subject: [PATCH] Fix local load for peft (#1373) local directory overloaded still needs the directory to locate the weights files correctly. # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- server/tests/utils/test_hub.py | 10 +++++++++- server/text_generation_server/utils/hub.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/server/tests/utils/test_hub.py b/server/tests/utils/test_hub.py index 49549893..721820f5 100644 --- a/server/tests/utils/test_hub.py +++ b/server/tests/utils/test_hub.py @@ -61,7 +61,15 @@ def test_weight_hub_files_offline_error(offline, fresh_cache): def test_weight_hub_files_offline_ok(prefetched, offline): # If the model is prefetched then we should be able to get the weight files from local cache filenames = weight_hub_files(prefetched) - assert filenames == ["model.safetensors"] + root = None + assert len(filenames) == 1 + for f in filenames: + curroot, filename = os.path.split(f) + if root is None: + root = curroot + else: + assert root == curroot + assert filename == "model.safetensors" def test_weight_hub_files(): diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index deb1a941..b56484f6 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -49,7 +49,7 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]: # done there with the len(s.rfilename.split("/")) == 1 condition root, _, files = next(os.walk(str(d))) filenames = [ - f + os.path.join(root, f) for f in files if f.endswith(extension) and "arguments" not in f