diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index 1ea828020..5779a2748 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -1006,12 +1006,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): super().__init__(config) self.model_dim = config.d_model - try: - self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) - except RuntimeError: - self.shared = TensorParallelEmbedding( - prefix="encoder.embed_tokens", weights=weights - ) + self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 1b7073af8..133aafd80 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -55,7 +55,16 @@ class T5Sharded(Seq2SeqLM): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + aliases={ + "shared.weight": [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + ] + }, ) model = T5ForConditionalGeneration(config, weights)