diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 07c8e020..24efe62c 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -54,7 +54,7 @@ class NextTokenChooser: ) self.grammar_processor = ( GrammarLogitProcessor(tokenizer, device, states_to_token_maps) - if states_to_token_maps + if len(states_to_token_maps.start_states) > 0 else None ) self.tokenizer = tokenizer @@ -264,7 +264,7 @@ class HeterogeneousNextTokenChooser: self.grammar_processor = ( HeterogeneousGrammarLogitProcessor(tokenizer, device, states_to_token_maps) - if any(states_to_token_maps) + if any([len(x.start_states) > 0 for x in states_to_token_maps]) else None ) @@ -434,10 +434,13 @@ class HeterogeneousNextTokenChooser: self.seeds = [self.seeds[i] for i in indices] self.do_sample = [self.do_sample[i] for i in indices] + new_states_to_token_maps = [] new_fsm_grammar_states = [] for i in indices: + new_states_to_token_maps.append(self.states_to_token_maps[i]) new_fsm_grammar_states.append(self.fsm_grammar_states[i]) + self.states_to_token_maps = new_states_to_token_maps self.fsm_grammar_states = new_fsm_grammar_states if any(self.do_sample):