From ff7774f2f1b42f5310f66d46a2e112c4a2eef3ec Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 24 Feb 2023 16:56:34 +0100 Subject: [PATCH] fix(server): fix token_is_special --- server/text_generation/models/causal_lm.py | 2 +- server/text_generation/models/seq2seq_lm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py index d15197d0..30aff87f 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation/models/causal_lm.py @@ -445,7 +445,7 @@ class CausalLM(Model): next_token_id_squeezed, next_token_logprob, next_token_text, - next_token_id_squeezed in self.all_special_ids, + next_token_id_squeezed.item() in self.all_special_ids, generated_text, ) diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index 3a4108ab..3738d7ab 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -509,7 +509,7 @@ class Seq2SeqLM(Model): next_token_id_squeezed, next_token_logprob, next_token_text, - next_token_id_squeezed in self.all_special_ids, + next_token_id_squeezed.item() in self.all_special_ids, generated_text, )