mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix
This commit is contained in:
parent
a337182b43
commit
346bc74acd
@ -341,7 +341,7 @@ class HeterogeneousNextTokenChooser:
|
|||||||
for warper in self.warpers:
|
for warper in self.warpers:
|
||||||
_scores = warper(input_ids, _scores)
|
_scores = warper(input_ids, _scores)
|
||||||
if self.grammar_processor is not None:
|
if self.grammar_processor is not None:
|
||||||
_scores = self.grammar_processor(_scores, self.fsm_grammar_states, mask)
|
_scores = self.grammar_processor(_scores, self.fsm_grammar_states)
|
||||||
_next_ids = self.choice(_scores)
|
_next_ids = self.choice(_scores)
|
||||||
scores[:, j] = _scores
|
scores[:, j] = _scores
|
||||||
next_ids[:, j] = _next_ids
|
next_ids[:, j] = _next_ids
|
||||||
@ -402,7 +402,7 @@ class HeterogeneousNextTokenChooser:
|
|||||||
def advance_grammar(self, next_ids: List[int]):
|
def advance_grammar(self, next_ids: List[int]):
|
||||||
if self.grammar_processor is not None:
|
if self.grammar_processor is not None:
|
||||||
other_new_states = self.grammar_processor.advance_batch(
|
other_new_states = self.grammar_processor.advance_batch(
|
||||||
next_ids, self.fsm_grammar_states, self.grammars
|
next_ids, self.fsm_grammar_states
|
||||||
)
|
)
|
||||||
self.fsm_grammar_states = other_new_states
|
self.fsm_grammar_states = other_new_states
|
||||||
return self
|
return self
|
||||||
|
Loading…
Reference in New Issue
Block a user