fix: check start_states len and add states_to_token_maps to filter

drbh 2024-03-08 22:58:30 +00:00
parent 5b5cbb14d6
commit 65b5b4c36e


@@ -54,7 +54,7 @@ class NextTokenChooser:
         )
         self.grammar_processor = (
             GrammarLogitProcessor(tokenizer, device, states_to_token_maps)
-            if states_to_token_maps
+            if len(states_to_token_maps.start_states) > 0
             else None
         )
         self.tokenizer = tokenizer
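
Note on this hunk: a bare "if states_to_token_maps" tests object truthiness, which does not necessarily reflect whether a grammar FSM is actually present, so the explicit start_states length check is the safer guard. Below is a minimal sketch of that guard, using a SimpleNamespace stand-in and a placeholder return value instead of the real GrammarLogitProcessor; both stand-ins are illustrative assumptions, not TGI code.

from types import SimpleNamespace

def build_grammar_processor(states_to_token_maps):
    # An arbitrary object is truthy even when it carries no FSM, so a bare
    # truthiness test is not a reliable emptiness check; start_states is.
    if len(states_to_token_maps.start_states) > 0:
        return "grammar processor"  # placeholder for GrammarLogitProcessor(...)
    return None

empty = SimpleNamespace(start_states=[])
with_grammar = SimpleNamespace(start_states=[0])

print(bool(empty))                            # True -> a plain truthiness check would pass here
print(build_grammar_processor(empty))         # None
print(build_grammar_processor(with_grammar))  # grammar processor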
@@ -264,7 +264,7 @@ class HeterogeneousNextTokenChooser:
         self.grammar_processor = (
             HeterogeneousGrammarLogitProcessor(tokenizer, device, states_to_token_maps)
-            if any(states_to_token_maps)
+            if any([len(x.start_states) > 0 for x in states_to_token_maps])
            else None
         )
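
The batch-level condition follows the same reasoning: any(states_to_token_maps) is already true once the list holds truthy objects, whether or not any of them carries start states, so the new version checks each element. A short illustration with stand-in objects rather than the real payload type:

from types import SimpleNamespace

batch_maps = [
    SimpleNamespace(start_states=[]),  # request without a grammar
    SimpleNamespace(start_states=[]),  # request without a grammar
]

print(any(batch_maps))                                   # True: the objects are truthy
print(any(len(m.start_states) > 0 for m in batch_maps))  # False: no grammar in the batch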
@@ -434,10 +434,13 @@ class HeterogeneousNextTokenChooser:
         self.seeds = [self.seeds[i] for i in indices]
         self.do_sample = [self.do_sample[i] for i in indices]
 
+        new_states_to_token_maps = []
         new_fsm_grammar_states = []
         for i in indices:
+            new_states_to_token_maps.append(self.states_to_token_maps[i])
             new_fsm_grammar_states.append(self.fsm_grammar_states[i])
 
+        self.states_to_token_maps = new_states_to_token_maps
         self.fsm_grammar_states = new_fsm_grammar_states
 
         if any(self.do_sample):
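
The filter hunk keeps the per-request bookkeeping aligned: when the batch is narrowed to indices, every parallel list has to be re-sliced in the same order, and states_to_token_maps now gets the same treatment as fsm_grammar_states. A standalone sketch of that pattern follows; the chooser object and the sample values are made up for illustration.

from types import SimpleNamespace

chooser = SimpleNamespace(
    states_to_token_maps=["map_a", "map_b", "map_c"],
    fsm_grammar_states=[0, 3, 7],
)
indices = [0, 2]  # requests kept after filtering the batch

new_states_to_token_maps = []
new_fsm_grammar_states = []
for i in indices:
    # Re-slice both lists in lockstep so position i still refers to the
    # same request in each of them.
    new_states_to_token_maps.append(chooser.states_to_token_maps[i])
    new_fsm_grammar_states.append(chooser.fsm_grammar_states[i])

chooser.states_to_token_maps = new_states_to_token_maps
chooser.fsm_grammar_states = new_fsm_grammar_states

print(chooser.states_to_token_maps)  # ['map_a', 'map_c']
print(chooser.fsm_grammar_states)    # [0, 7]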