From 85acb11ba0ed9de0bdc18047478adaaa041baacb Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 3 Oct 2023 11:55:10 +0200 Subject: [PATCH] Handling bloom prefix. (#1090) # What does this PR do? Adds an optional `prefix` argument to `Weights` so that tensor names can be resolved under a checkpoint-specific prefix, and passes `prefix="transformer"` for BLOOM, whose safetensors weights are stored under the `transformer.` prefix; `get_filename` now tries the bare name, the prefixed name, and each name's aliases before raising. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. 
--- server/text_generation_server/models/bloom.py | 2 +- .../text_generation_server/utils/weights.py | 19 ++++++++++++++----- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 0151b017..8e8daad3 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -74,7 +74,7 @@ class BLOOMSharded(CausalLM): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group + filenames, device=device, dtype=dtype, process_group=self.process_group, prefix="transformer", ) if config.quantize == "gptq": weights._set_gptq_params(model_id) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 8a19fd9f..4bae8cc0 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -16,6 +16,7 @@ class Weights: dtype, process_group, aliases: Optional[Dict[str, List[str]]] = None, + prefix: Optional[str] = None ): routing = {} for filename in filenames: @@ -33,6 +34,7 @@ class Weights: self.device = device self.dtype = dtype self.process_group = process_group + self.prefix = prefix self._handles = {} def _get_handle(self, filename): @@ -43,15 +45,22 @@ class Weights: return self._handles[filename] def get_filename(self, tensor_name: str) -> (str, str): - filename = self.routing.get(tensor_name, None) - if filename is None: - aliases = self.aliases.get(tensor_name, []) + + names = [tensor_name] + if self.prefix is not None: + prefixed = f"{self.prefix}.{tensor_name}" + names.append(prefixed) + for name in names: + filename = self.routing.get(name, None) + if filename is not None: + return str(filename), name + + aliases = self.aliases.get(name, []) for alias 
in aliases: filename = self.routing.get(alias, None) if filename is not None: return str(filename), alias - raise RuntimeError(f"weight {tensor_name} does not exist") - return str(filename), tensor_name + raise RuntimeError(f"weight {tensor_name} does not exist") def _get_slice(self, tensor_name: str): filename, tensor_name = self.get_filename(tensor_name)