mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
Revert "Map deduplicated tensors via metadata"
This reverts commit d6bb10f202
.
This commit is contained in:
parent
d6bb10f202
commit
81f234ec61
@ -7,10 +7,8 @@ import torch
|
||||
class Weights:
|
||||
def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None):
|
||||
routing = {}
|
||||
metadata = {}
|
||||
for filename in filenames:
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
metadata |= f.metadata()
|
||||
for k in f.keys():
|
||||
if k in routing:
|
||||
raise RuntimeError(
|
||||
@ -19,7 +17,6 @@ class Weights:
|
||||
routing[k] = filename
|
||||
if aliases is None:
|
||||
aliases = {}
|
||||
self.metadata = metadata
|
||||
self.aliases = aliases
|
||||
self.routing = routing
|
||||
self.device = device
|
||||
@ -35,8 +32,6 @@ class Weights:
|
||||
return self._handles[filename]
|
||||
|
||||
def get_filename(self, tensor_name: str) -> (str, str):
|
||||
if tensor_name in self.metadata:
|
||||
tensor_name = self.metadata[tensor_name]
|
||||
filename = self.routing.get(tensor_name, None)
|
||||
if filename is None:
|
||||
aliases = self.aliases.get(tensor_name, [])
|
||||
|
Loading…
Reference in New Issue
Block a user