fix seq2seq

This commit is contained in:
OlivierDehaene 2023-02-10 19:25:07 +01:00
parent e9441a1ea2
commit 3fa6a4e674

View File

@@ -423,7 +423,7 @@ class Seq2SeqLM(Model):
             )
             # Append next token to decoder tokens
-            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
+            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.squeeze(1)])
             new_decoder_input_length = decoder_input_length + 1
             # Generated token