Map deduplicated tensors via metadata

This PR automatically points tensors that were removed due to
deduplication to their still existing twin.

In `server.text_generation_server.utils.convert.py#convert_file`,
duplicated tensors are removed and logged to the "metadata" dictionary.
However, this dictionary was not yet used during loading. This requires
explicit remapping when loading the models (as mentioned in the
docstring).

What does this fix?
We currently cannot load `h2oai/h2ogpt-oig-oasst1-falcon-40b` with the
unmodified server, since the `transformer.word_embeddings.weight` weight
is equal to `lm_head.weight` and is automatically removed.
This commit is contained in:
Vincent Brouwers 2023-06-28 17:18:01 +00:00
parent 70f485bf9f
commit d6bb10f202

View File

@ -7,8 +7,10 @@ import torch
class Weights: class Weights:
def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None): def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None):
routing = {} routing = {}
metadata = {}
for filename in filenames: for filename in filenames:
with safe_open(filename, framework="pytorch") as f: with safe_open(filename, framework="pytorch") as f:
metadata |= f.metadata()
for k in f.keys(): for k in f.keys():
if k in routing: if k in routing:
raise RuntimeError( raise RuntimeError(
@ -17,6 +19,7 @@ class Weights:
routing[k] = filename routing[k] = filename
if aliases is None: if aliases is None:
aliases = {} aliases = {}
self.metadata = metadata
self.aliases = aliases self.aliases = aliases
self.routing = routing self.routing = routing
self.device = device self.device = device
@ -32,6 +35,8 @@ class Weights:
return self._handles[filename] return self._handles[filename]
def get_filename(self, tensor_name: str) -> (str, str): def get_filename(self, tensor_name: str) -> (str, str):
if tensor_name in self.metadata:
tensor_name = self.metadata[tensor_name]
filename = self.routing.get(tensor_name, None) filename = self.routing.get(tensor_name, None)
if filename is None: if filename is None:
aliases = self.aliases.get(tensor_name, []) aliases = self.aliases.get(tensor_name, [])