Fix shared weights load bug and T5 loading

This commit is contained in:
zhangsibo1129 2023-09-26 17:57:59 +08:00
parent c5de7cd886
commit 57433201b2
2 changed files with 4 additions and 3 deletions

View File

@ -63,6 +63,7 @@ class T5Sharded(Seq2SeqLM):
"shared.weight": [
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
"lm_head.weight",
]
},
)

View File

@ -45,8 +45,8 @@ class Weights:
def get_filename(self, tensor_name: str) -> (str, str):
filename = self.routing.get(tensor_name, None)
if filename is None:
aliases = self.aliases.get(tensor_name, [])
for alias in aliases:
for alias, tensor_list in self.aliases.items():
if tensor_name in tensor_list:
filename = self.routing.get(alias, None)
if filename is not None:
return str(filename), alias