mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Better fix.
This commit is contained in:
parent
e19d0e7867
commit
ce8eaaf2be
@ -1032,9 +1032,17 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
embed_tokens=self.shared,
|
||||
)
|
||||
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
config, prefix="shared", weights=weights
|
||||
)
|
||||
try:
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
config, prefix="lm_head", weights=weights
|
||||
)
|
||||
except RuntimeError:
|
||||
# Some models like t5-small were saved with shared weights unlike flan
|
||||
# Since they are declared as the same arch we have no choice but hope
|
||||
# that this is OK instead of using a proper flag.
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
config, prefix="shared", weights=weights
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
Loading…
Reference in New Issue
Block a user