This commit is contained in:
OlivierDehaene 2024-02-16 17:18:15 +01:00
parent a337182b43
commit 346bc74acd

View File

@ -341,7 +341,7 @@ class HeterogeneousNextTokenChooser:
for warper in self.warpers: for warper in self.warpers:
_scores = warper(input_ids, _scores) _scores = warper(input_ids, _scores)
if self.grammar_processor is not None: if self.grammar_processor is not None:
_scores = self.grammar_processor(_scores, self.fsm_grammar_states, mask) _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
_next_ids = self.choice(_scores) _next_ids = self.choice(_scores)
scores[:, j] = _scores scores[:, j] = _scores
next_ids[:, j] = _next_ids next_ids[:, j] = _next_ids
@ -402,7 +402,7 @@ class HeterogeneousNextTokenChooser:
def advance_grammar(self, next_ids: List[int]): def advance_grammar(self, next_ids: List[int]):
if self.grammar_processor is not None: if self.grammar_processor is not None:
other_new_states = self.grammar_processor.advance_batch( other_new_states = self.grammar_processor.advance_batch(
next_ids, self.fsm_grammar_states, self.grammars next_ids, self.fsm_grammar_states
) )
self.fsm_grammar_states = other_new_states self.fsm_grammar_states = other_new_states
return self return self