From 5573f229c8c753fce904478981345fccc684148b Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 20 Jun 2023 14:03:29 +0000 Subject: [PATCH] Fixing T5 in case the names are mixed up. --- .../models/custom_modeling/t5_modeling.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index 51862e3c..12679e9d 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -1001,7 +1001,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel): super().__init__(config) self.model_dim = config.d_model - self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) + try: + self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) + except RuntimeError: + self.shared = TensorParallelEmbedding(prefix="encoder.embed_tokens", weights=weights) encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False