Revert "Map deduplicated tensors via metadata"

This reverts commit d6bb10f202.
This commit is contained in:
Nicolas Patry 2023-07-04 11:30:35 +02:00
parent d6bb10f202
commit 81f234ec61

View File

@ -7,10 +7,8 @@ import torch
class Weights:
def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None):
routing = {}
metadata = {}
for filename in filenames:
with safe_open(filename, framework="pytorch") as f:
metadata |= f.metadata()
for k in f.keys():
if k in routing:
raise RuntimeError(
@ -19,7 +17,6 @@ class Weights:
routing[k] = filename
if aliases is None:
aliases = {}
self.metadata = metadata
self.aliases = aliases
self.routing = routing
self.device = device
@ -35,8 +32,6 @@ class Weights:
return self._handles[filename]
def get_filename(self, tensor_name: str) -> (str, str):
if tensor_name in self.metadata:
tensor_name = self.metadata[tensor_name]
filename = self.routing.get(tensor_name, None)
if filename is None:
aliases = self.aliases.get(tensor_name, [])