diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 93cd7ba0..72c6c21c 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -341,7 +341,7 @@ class HeterogeneousNextTokenChooser:
             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)
             if self.grammar_processor is not None:
-                _scores = self.grammar_processor(_scores, self.fsm_grammar_states, mask)
+                _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
             _next_ids = self.choice(_scores)
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids
@@ -402,7 +402,7 @@ class HeterogeneousNextTokenChooser:
     def advance_grammar(self, next_ids: List[int]):
         if self.grammar_processor is not None:
             other_new_states = self.grammar_processor.advance_batch(
-                next_ids, self.fsm_grammar_states, self.grammars
+                next_ids, self.fsm_grammar_states
             )
             self.fsm_grammar_states = other_new_states
         return self