mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00

T5?

This commit is contained in:
parent 680f26d6b2
commit e36e42a3f4
@@ -729,7 +729,6 @@ class T5PreTrainedModel(PreTrainedModel):
     """

     config_class = T5Config
-    base_model_prefix = "transformer"

     def _shift_right(self, input_ids):
         decoder_start_token_id = self.config.decoder_start_token_id
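For reference, the _shift_right helper visible as context in this hunk is untouched by the commit. A sketch of the standard implementation as it appears in upstream transformers, shown only to clarify what the context lines belong to (assumed to match the custom modeling file here):

    def _shift_right(self, input_ids):
        # Teacher forcing: prepend the decoder start token and drop the last
        # position, so decoder inputs are the targets shifted one step right.
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id
        # Labels may use -100 as an ignore index; map those back to pad tokens.
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        return shifted_input_ids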
@@ -1021,7 +1020,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             embed_tokens=self.shared,
         )

-        self.lm_head = TensorParallelHead.load(config, prefix="shared", weights=weights)
+        self.lm_head = TensorParallelHead.load(config, prefix="lm_head", weights=weights)

     def forward(
         self,
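The functional change in this hunk points the language-model head loader at the checkpoint's lm_head weights instead of reusing the shared embedding. Checkpoints in the T5 v1.1 / Flan-T5 family keep lm_head.weight untied from shared.weight, so building the head from "shared" would pick up the wrong tensor for those models. A minimal, hypothetical way to confirm which tensors a downloaded safetensors shard actually contains (the file name is illustrative, not part of this commit):

    from safetensors import safe_open

    # Hypothetical check: list the keys in a local shard to see whether the
    # checkpoint stores a dedicated lm_head.weight or only shared.weight.
    with safe_open("model.safetensors", framework="pt") as f:
        keys = list(f.keys())

    print("lm_head.weight present:", any(k.endswith("lm_head.weight") for k in keys))
    print("shared.weight present:", any(k.endswith("shared.weight") for k in keys))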
@@ -35,7 +35,9 @@ class T5Sharded(Seq2SeqLM):
             device = torch.device("cpu")
             dtype = torch.float32

-        config = AutoConfig.from_pretrained(model_id, revision=revision)
+        config = AutoConfig.from_pretrained(model_id, revision=revision,
+            trust_remote_code=trust_remote_code,
+        )
         config.quantize = quantize

         tokenizer = AutoTokenizer.from_pretrained(
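The last hunk threads trust_remote_code into AutoConfig.from_pretrained, presumably so that T5 variants published with custom code can be resolved by the sharded loader. A minimal sketch of the resulting call shape, assuming an illustrative checkpoint name (both Auto* loaders in transformers accept the same keyword):

    from transformers import AutoConfig, AutoTokenizer

    model_id = "google/flan-t5-base"  # illustrative checkpoint, not from the diff
    revision = None
    trust_remote_code = False

    # Config and tokenizer are resolved with the same revision / remote-code
    # settings, so a model that ships custom code loads consistently.
    config = AutoConfig.from_pretrained(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )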