fix seq2seq

This commit is contained in:
OlivierDehaene 2023-02-10 19:25:07 +01:00
parent e9441a1ea2
commit 3fa6a4e674

View File

@ -423,7 +423,7 @@ class Seq2SeqLM(Model):
)
# Append next token to decoder tokens
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.squeeze(1)])
new_decoder_input_length = decoder_input_length + 1
# Generated token