From 3fa6a4e6741f8ce9950e2c3151f4d396a2193f75 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 10 Feb 2023 19:25:07 +0100 Subject: [PATCH] fix seq2seq --- server/text_generation/models/seq2seq_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index d6cccd44..2f28c4ce 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -423,7 +423,7 @@ class Seq2SeqLM(Model): ) # Append next token to decoder tokens - decoder_input_ids = torch.cat([decoder_input_ids, next_token_id]) + decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.squeeze(1)]) new_decoder_input_length = decoder_input_length + 1 # Generated token