From c9c65ab323f48731e1fc2f7087547a7bd8b753f2 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 20 Jun 2023 18:03:36 +0200
Subject: [PATCH] fix(server): Fixing T5 in case the names are mixed up. (#475)

---
 .../models/custom_modeling/t5_modeling.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py
index 51862e3c..12679e9d 100644
--- a/server/text_generation_server/models/custom_modeling/t5_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py
@@ -1001,7 +1001,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         super().__init__(config)
         self.model_dim = config.d_model
 
-        self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
+        try:
+            self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
+        except RuntimeError:
+            self.shared = TensorParallelEmbedding(prefix="encoder.embed_tokens", weights=weights)
 
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False