mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Fix shared weights load bug and T5 loading
This commit is contained in:
parent
c5de7cd886
commit
57433201b2
@@ -63,6 +63,7 @@ class T5Sharded(Seq2SeqLM):
|
|||||||
"shared.weight": [
|
"shared.weight": [
|
||||||
"encoder.embed_tokens.weight",
|
"encoder.embed_tokens.weight",
|
||||||
"decoder.embed_tokens.weight",
|
"decoder.embed_tokens.weight",
|
||||||
|
"lm_head.weight",
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@@ -45,8 +45,8 @@ class Weights:
|
|||||||
def get_filename(self, tensor_name: str) -> (str, str):
|
def get_filename(self, tensor_name: str) -> (str, str):
|
||||||
filename = self.routing.get(tensor_name, None)
|
filename = self.routing.get(tensor_name, None)
|
||||||
if filename is None:
|
if filename is None:
|
||||||
aliases = self.aliases.get(tensor_name, [])
|
for alias, tensor_list in self.aliases.items():
|
||||||
for alias in aliases:
|
if tensor_name in tensor_list:
|
||||||
filename = self.routing.get(alias, None)
|
filename = self.routing.get(alias, None)
|
||||||
if filename is not None:
|
if filename is not None:
|
||||||
return str(filename), alias
|
return str(filename), alias
|
||||||
|
Loading…
Reference in New Issue
Block a user