From d6bb10f2025913dd3f4af8de93bd8efbcf4bd2c2 Mon Sep 17 00:00:00 2001
From: Vincent Brouwers
Date: Wed, 28 Jun 2023 17:18:01 +0000
Subject: [PATCH] Map deduplicated tensors via metadata

This PR automatically points tensors that were removed during deduplication
to their still-existing twin.

In `server.text_generation_server.utils.convert.py#convert_file`, duplicated
tensors are removed and recorded in the "metadata" dictionary. However, this
dictionary was not yet consulted during loading, even though (as mentioned in
the docstring there) loading the models requires explicit remapping.

What does this fix?
We currently cannot load `h2oai/h2ogpt-oig-oasst1-falcon-40b` with the
unmodified server, because its `transformer.word_embeddings.weight` tensor is
equal to `lm_head.weight` and is therefore automatically removed during
conversion.
---
 server/text_generation_server/utils/weights.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 9d371834..e566d66b 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -7,8 +7,10 @@ import torch
 class Weights:
     def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None):
         routing = {}
+        metadata = {}
         for filename in filenames:
             with safe_open(filename, framework="pytorch") as f:
+                metadata |= f.metadata()
                 for k in f.keys():
                     if k in routing:
                         raise RuntimeError(
@@ -17,6 +19,7 @@ class Weights:
                     routing[k] = filename
         if aliases is None:
             aliases = {}
+        self.metadata = metadata
         self.aliases = aliases
         self.routing = routing
         self.device = device
@@ -32,6 +35,8 @@ class Weights:
         return self._handles[filename]
 
     def get_filename(self, tensor_name: str) -> (str, str):
+        if tensor_name in self.metadata:
+            tensor_name = self.metadata[tensor_name]
         filename = self.routing.get(tensor_name, None)
         if filename is None:
             aliases = self.aliases.get(tensor_name, [])
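
For context, here is a minimal, self-contained sketch (not part of the patch) of the mechanism the commit message describes, assuming, as the new `get_filename` lookup implies, that `convert_file` stores a removed-name to kept-name mapping in the safetensors metadata. The file path, tensor contents, and the `resolve` helper are hypothetical and only illustrate the behaviour.

```python
# Illustrative sketch only; the file path, tensor values, and `resolve` helper
# are made up. It simulates the conversion-side metadata and the loading-side
# remapping that this patch adds to Weights.
import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Conversion side: the duplicated `lm_head.weight` is dropped, and the
# safetensors metadata records which surviving tensor it should resolve to.
shared = torch.zeros(4, 4)
save_file(
    {"transformer.word_embeddings.weight": shared},
    "model.safetensors",
    metadata={"lm_head.weight": "transformer.word_embeddings.weight"},
)

# Loading side: mirrors what the patched __init__ and get_filename do.
routing, metadata = {}, {}
with safe_open("model.safetensors", framework="pt") as f:
    metadata |= f.metadata() or {}  # metadata() can be None for files without it
    for k in f.keys():
        routing[k] = "model.safetensors"

def resolve(tensor_name: str) -> str:
    # A deduplicated tensor has no entry in `routing`, but the metadata points
    # to the twin that still exists in the file.
    if tensor_name in metadata:
        tensor_name = metadata[tensor_name]
    return routing[tensor_name]

print(resolve("lm_head.weight"))  # -> model.safetensors
```

Running the sketch resolves `lm_head.weight` to the file that still contains `transformer.word_embeddings.weight`, which mirrors the lookup that previously failed for `h2oai/h2ogpt-oig-oasst1-falcon-40b`.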