Fixing t5 loading.

This commit is contained in:
Nicolas Patry 2023-09-21 08:29:48 +02:00
parent 123749a3c9
commit e19d0e7867

View File

@ -1033,7 +1033,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
) )
self.lm_head = TensorParallelHead.load( self.lm_head = TensorParallelHead.load(
config, prefix="lm_head", weights=weights config, prefix="shared", weights=weights
) )
def forward( def forward(