diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py
index 51862e3c..12679e9d 100644
--- a/server/text_generation_server/models/custom_modeling/t5_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py
@@ -1001,7 +1001,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         super().__init__(config)
         self.model_dim = config.d_model
 
-        self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
+        try:
+            self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
+        except RuntimeError:
+            self.shared = TensorParallelEmbedding(prefix="encoder.embed_tokens", weights=weights)
 
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False