fix: check start_states len and add states_to_token_maps to filter

This commit is contained in:
drbh 2024-03-08 22:58:30 +00:00
parent 5b5cbb14d6
commit 65b5b4c36e

View File

@@ -54,7 +54,7 @@ class NextTokenChooser:
         )
         self.grammar_processor = (
             GrammarLogitProcessor(tokenizer, device, states_to_token_maps)
-            if states_to_token_maps
+            if len(states_to_token_maps.start_states) > 0
             else None
         )
         self.tokenizer = tokenizer
@@ -264,7 +264,7 @@ class HeterogeneousNextTokenChooser:
         self.grammar_processor = (
             HeterogeneousGrammarLogitProcessor(tokenizer, device, states_to_token_maps)
-            if any(states_to_token_maps)
+            if any([len(x.start_states) > 0 for x in states_to_token_maps])
             else None
         )
@@ -434,10 +434,13 @@ class HeterogeneousNextTokenChooser:
         self.seeds = [self.seeds[i] for i in indices]
         self.do_sample = [self.do_sample[i] for i in indices]
+        new_states_to_token_maps = []
         new_fsm_grammar_states = []
         for i in indices:
+            new_states_to_token_maps.append(self.states_to_token_maps[i])
             new_fsm_grammar_states.append(self.fsm_grammar_states[i])
+        self.states_to_token_maps = new_states_to_token_maps
         self.fsm_grammar_states = new_fsm_grammar_states
         if any(self.do_sample):