From 3f792254726853feb0ac9355e0def8cd53df4846 Mon Sep 17 00:00:00 2001 From: Alex Weston Date: Fri, 25 Oct 2024 11:21:17 -0400 Subject: [PATCH] Check if allowed tokens is None --- server/text_generation_server/utils/logits_process.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index e79a4bb0..ec2813a1 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -498,9 +498,10 @@ class GrammarLogitProcessor(LogitsProcessor): ): if fsm_grammar_state == -1 or self.fsm is None: return logits - allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state) + allowed_tokens = self.fsm.get_next_instruction(fsm_grammar_state).tokens mask = torch.full_like(logits, -math.inf) - mask[:, allowed_tokens] = 0 + if allowed_tokens is not None: + mask[:, allowed_tokens] = 0 biased_scores = logits + mask return biased_scores @@ -589,7 +590,8 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): if fsm_grammar_states[i] == -1 or fsm is None: continue allowed_tokens = fsm.get_next_instruction(fsm_grammar_states[i]).tokens - mask[i, allowed_tokens] = 0 + if allowed_tokens is not None: + mask[i, allowed_tokens] = 0 logits[i] += mask[i] return logits