From 65b5b4c36e131456674132f1d595304d7aacba34 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 8 Mar 2024 22:58:30 +0000 Subject: [PATCH] fix: check start_states len and add states_to_token_maps to filter --- server/text_generation_server/utils/tokens.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 07c8e020..24efe62c 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -54,7 +54,7 @@ class NextTokenChooser: ) self.grammar_processor = ( GrammarLogitProcessor(tokenizer, device, states_to_token_maps) - if states_to_token_maps + if len(states_to_token_maps.start_states) > 0 else None ) self.tokenizer = tokenizer @@ -264,7 +264,7 @@ class HeterogeneousNextTokenChooser: self.grammar_processor = ( HeterogeneousGrammarLogitProcessor(tokenizer, device, states_to_token_maps) - if any(states_to_token_maps) + if any([len(x.start_states) > 0 for x in states_to_token_maps]) else None ) @@ -434,10 +434,13 @@ class HeterogeneousNextTokenChooser: self.seeds = [self.seeds[i] for i in indices] self.do_sample = [self.do_sample[i] for i in indices] + new_states_to_token_maps = [] new_fsm_grammar_states = [] for i in indices: + new_states_to_token_maps.append(self.states_to_token_maps[i]) new_fsm_grammar_states.append(self.fsm_grammar_states[i]) + self.states_to_token_maps = new_states_to_token_maps self.fsm_grammar_states = new_fsm_grammar_states if any(self.do_sample):