mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Fix shared weights load bug and T5 loading
This commit is contained in:
parent
c5de7cd886
commit
57433201b2
@ -63,6 +63,7 @@ class T5Sharded(Seq2SeqLM):
|
||||
"shared.weight": [
|
||||
"encoder.embed_tokens.weight",
|
||||
"decoder.embed_tokens.weight",
|
||||
"lm_head.weight",
|
||||
]
|
||||
},
|
||||
)
|
||||
|
@ -45,9 +45,9 @@ class Weights:
|
||||
def get_filename(self, tensor_name: str) -> (str, str):
|
||||
filename = self.routing.get(tensor_name, None)
|
||||
if filename is None:
|
||||
aliases = self.aliases.get(tensor_name, [])
|
||||
for alias in aliases:
|
||||
filename = self.routing.get(alias, None)
|
||||
for alias, tensor_list in self.aliases.items():
|
||||
if tensor_name in tensor_list:
|
||||
filename = self.routing.get(alias, None)
|
||||
if filename is not None:
|
||||
return str(filename), alias
|
||||
raise RuntimeError(f"weight {tensor_name} does not exist")
|
||||
|
Loading…
Reference in New Issue
Block a user