diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index ec2813a1..d53f070c 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -501,7 +501,7 @@ class GrammarLogitProcessor(LogitsProcessor): allowed_tokens = self.fsm.get_next_instruction(fsm_grammar_state).tokens mask = torch.full_like(logits, -math.inf) if allowed_tokens is not None: - mask[:, allowed_tokens] = 0 + mask[:, allowed_tokens] = 0 biased_scores = logits + mask return biased_scores