Better fix.

This commit is contained in:
Nicolas Patry 2023-09-25 09:50:42 +00:00
parent e19d0e7867
commit ce8eaaf2be

View File

@ -1032,6 +1032,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
embed_tokens=self.shared,
)
try:
self.lm_head = TensorParallelHead.load(
config, prefix="lm_head", weights=weights
)
except RuntimeError:
# Some models, like t5-small, were saved with shared weights, unlike flan.
# Since they are declared as the same arch, we have no choice but to hope
# that this is OK instead of using a proper flag.
self.lm_head = TensorParallelHead.load(
config, prefix="shared", weights=weights
)