Fix shared weights load bug and T5 loading

zhangsibo1129 2023-09-26 17:57:59 +08:00
parent c5de7cd886
commit 57433201b2
2 changed files with 4 additions and 3 deletions

@@ -63,6 +63,7 @@ class T5Sharded(Seq2SeqLM):
                 "shared.weight": [
                     "encoder.embed_tokens.weight",
                     "decoder.embed_tokens.weight",
+                    "lm_head.weight",
                 ]
             },
         )
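
For T5 checkpoints that tie the LM head to the shared embedding, only shared.weight is serialized, so a request for lm_head.weight has to fall back to that tensor; the added list entry registers the LM head name alongside the two embed_tokens aliases. A minimal sketch of the resulting alias table passed to Weights, keyed by the serialized name as the lookup below expects (contents taken from the hunk above, nothing else assumed):

# Alias table: key = tensor name present in the checkpoint,
# value = tied names that should resolve to it.
aliases = {
    "shared.weight": [
        "encoder.embed_tokens.weight",
        "decoder.embed_tokens.weight",
        "lm_head.weight",  # added by this commit
    ],
}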

@@ -45,9 +45,9 @@ class Weights:
     def get_filename(self, tensor_name: str) -> (str, str):
         filename = self.routing.get(tensor_name, None)
         if filename is None:
-            aliases = self.aliases.get(tensor_name, [])
-            for alias in aliases:
-                filename = self.routing.get(alias, None)
+            for alias, tensor_list in self.aliases.items():
+                if tensor_name in tensor_list:
+                    filename = self.routing.get(alias, None)
                 if filename is not None:
                     return str(filename), alias
             raise RuntimeError(f"weight {tensor_name} does not exist")