This commit is contained in:
Ubuntu 2023-05-24 11:53:09 +00:00 committed by Nicolas Patry
parent 680f26d6b2
commit e36e42a3f4
2 changed files with 4 additions and 3 deletions

View File

@ -729,7 +729,6 @@ class T5PreTrainedModel(PreTrainedModel):
""" """
config_class = T5Config config_class = T5Config
base_model_prefix = "transformer"
def _shift_right(self, input_ids): def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id decoder_start_token_id = self.config.decoder_start_token_id
@ -1021,7 +1020,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
embed_tokens=self.shared, embed_tokens=self.shared,
) )
self.lm_head = TensorParallelHead.load(config, prefix="shared", weights=weights) self.lm_head = TensorParallelHead.load(config, prefix="lm_head", weights=weights)
def forward( def forward(
self, self,

View File

@ -35,7 +35,9 @@ class T5Sharded(Seq2SeqLM):
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32
config = AutoConfig.from_pretrained(model_id, revision=revision) config = AutoConfig.from_pretrained(model_id, revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize config.quantize = quantize
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(