mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
T5?
This commit is contained in:
parent
680f26d6b2
commit
e36e42a3f4
@ -729,7 +729,6 @@ class T5PreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
|
||||
config_class = T5Config
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
def _shift_right(self, input_ids):
|
||||
decoder_start_token_id = self.config.decoder_start_token_id
|
||||
@ -1021,7 +1020,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
embed_tokens=self.shared,
|
||||
)
|
||||
|
||||
self.lm_head = TensorParallelHead.load(config, prefix="shared", weights=weights)
|
||||
self.lm_head = TensorParallelHead.load(config, prefix="lm_head", weights=weights)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -35,7 +35,9 @@ class T5Sharded(Seq2SeqLM):
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
config = AutoConfig.from_pretrained(model_id, revision=revision)
|
||||
config = AutoConfig.from_pretrained(model_id, revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config.quantize = quantize
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
|
Loading…
Reference in New Issue
Block a user