Map deduplicated tensors via metadata

This PR automatically points tensors that were removed due to
deduplication to their still existing twin.

In `server.text_generation_server.utils.convert.py#convert_file`,
duplicated tensors are removed and logged to the "metadata" dictionary.
However, this dictionary was not yet used during loading. This requires
explicit remapping when loading the models (as mentioned in the
docstring).

What does this fix?
We currently cannot load `h2oai/h2ogpt-oig-oasst1-falcon-40b` with the
unmodified server, since the `transformer.word_embeddings.weight` weight
is equal to `lm_head.weight` and is automatically removed.
This commit is contained in:
Vincent Brouwers 2023-06-28 17:18:01 +00:00
parent 70f485bf9f
commit d6bb10f202

View File

@ -7,8 +7,10 @@ import torch
class Weights:
def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None):
routing = {}
metadata = {}
for filename in filenames:
with safe_open(filename, framework="pytorch") as f:
metadata |= f.metadata()
for k in f.keys():
if k in routing:
raise RuntimeError(
@ -17,6 +19,7 @@ class Weights:
routing[k] = filename
if aliases is None:
aliases = {}
self.metadata = metadata
self.aliases = aliases
self.routing = routing
self.device = device
@ -32,6 +35,8 @@ class Weights:
return self._handles[filename]
def get_filename(self, tensor_name: str) -> (str, str):
if tensor_name in self.metadata:
tensor_name = self.metadata[tensor_name]
filename = self.routing.get(tensor_name, None)
if filename is None:
aliases = self.aliases.get(tensor_name, [])