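Fix T5 loading in case the embedding weight names are mixed up: fall back to the `encoder.embed_tokens` prefix when the `shared` prefix cannot be loaded.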

This commit is contained in:
Ubuntu 2023-06-20 14:03:29 +00:00
parent 53aa9194c8
commit 5573f229c8

View File

@@ -1001,7 +1001,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
super().__init__(config) super().__init__(config)
self.model_dim = config.d_model self.model_dim = config.d_model
try:
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
except RuntimeError:
self.shared = TensorParallelEmbedding(prefix="encoder.embed_tokens", weights=weights)
encoder_config = copy.deepcopy(config) encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False encoder_config.is_decoder = False