mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
Map deduplicated tensors via metadata
This PR automatically points tensors that were removed due to deduplication to their still existing twin. In `server.text_generation_server.utils.convert.py#convert_file`, duplicated tensors are removed and logged to the "metadata" dictionary. However, this dictionary was not yet used during loading. This requires explicit remapping when loading the models (as mentioned in the docstring). What does this fix? We currently cannot load `h2oai/h2ogpt-oig-oasst1-falcon-40b` with the unmodified server, since the `transformer.word_embeddings.weight` weight is equal to `lm_head.weight` and is automatically removed.
This commit is contained in:
parent
70f485bf9f
commit
d6bb10f202
@ -7,8 +7,10 @@ import torch
|
|||||||
class Weights:
|
class Weights:
|
||||||
def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None):
|
def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None):
|
||||||
routing = {}
|
routing = {}
|
||||||
|
metadata = {}
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
with safe_open(filename, framework="pytorch") as f:
|
with safe_open(filename, framework="pytorch") as f:
|
||||||
|
metadata |= f.metadata()
|
||||||
for k in f.keys():
|
for k in f.keys():
|
||||||
if k in routing:
|
if k in routing:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@ -17,6 +19,7 @@ class Weights:
|
|||||||
routing[k] = filename
|
routing[k] = filename
|
||||||
if aliases is None:
|
if aliases is None:
|
||||||
aliases = {}
|
aliases = {}
|
||||||
|
self.metadata = metadata
|
||||||
self.aliases = aliases
|
self.aliases = aliases
|
||||||
self.routing = routing
|
self.routing = routing
|
||||||
self.device = device
|
self.device = device
|
||||||
@ -32,6 +35,8 @@ class Weights:
|
|||||||
return self._handles[filename]
|
return self._handles[filename]
|
||||||
|
|
||||||
def get_filename(self, tensor_name: str) -> (str, str):
|
def get_filename(self, tensor_name: str) -> (str, str):
|
||||||
|
if tensor_name in self.metadata:
|
||||||
|
tensor_name = self.metadata[tensor_name]
|
||||||
filename = self.routing.get(tensor_name, None)
|
filename = self.routing.get(tensor_name, None)
|
||||||
if filename is None:
|
if filename is None:
|
||||||
aliases = self.aliases.get(tensor_name, [])
|
aliases = self.aliases.get(tensor_name, [])
|
||||||
|
Loading…
Reference in New Issue
Block a user