diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index c5ce9bfc..0a9e3b77 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -729,7 +729,6 @@ class T5PreTrainedModel(PreTrainedModel): """ config_class = T5Config - base_model_prefix = "transformer" def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1021,7 +1020,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): embed_tokens=self.shared, ) - self.lm_head = TensorParallelHead.load(config, prefix="shared", weights=weights) + self.lm_head = TensorParallelHead.load(config, prefix="lm_head", weights=weights) def forward( self, diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 84465d48..e844c36f 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -35,7 +35,9 @@ class T5Sharded(Seq2SeqLM): device = torch.device("cpu") dtype = torch.float32 - config = AutoConfig.from_pretrained(model_id, revision=revision) + config = AutoConfig.from_pretrained(model_id, revision=revision, + trust_remote_code=trust_remote_code, + ) config.quantize = quantize tokenizer = AutoTokenizer.from_pretrained(