mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Better fix.
This commit is contained in:
parent
e19d0e7867
commit
ce8eaaf2be
@ -1032,9 +1032,17 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
embed_tokens=self.shared,
|
embed_tokens=self.shared,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.lm_head = TensorParallelHead.load(
|
try:
|
||||||
config, prefix="shared", weights=weights
|
self.lm_head = TensorParallelHead.load(
|
||||||
)
|
config, prefix="lm_head", weights=weights
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
# Some models, like t5-small, were saved with shared weights, unlike flan
|
||||||
|
# Since they are declared as the same arch, we have no choice but to hope
|
||||||
|
# that this is OK instead of using a proper flag.
|
||||||
|
self.lm_head = TensorParallelHead.load(
|
||||||
|
config, prefix="shared", weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
Loading…
Reference in New Issue
Block a user