diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index d6cccd44..2f28c4ce 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -423,7 +423,7 @@ class Seq2SeqLM(Model): ) # Append next token to decoder tokens - decoder_input_ids = torch.cat([decoder_input_ids, next_token_id]) + decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.squeeze(1)]) new_decoder_input_length = decoder_input_length + 1 # Generated token