mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-05-08 10:52:14 +00:00
fix: check start_states len and add states_to_token_maps to filter
This commit is contained in:
parent
5b5cbb14d6
commit
65b5b4c36e
@ -54,7 +54,7 @@ class NextTokenChooser:
|
||||
)
|
||||
self.grammar_processor = (
|
||||
GrammarLogitProcessor(tokenizer, device, states_to_token_maps)
|
||||
if states_to_token_maps
|
||||
if len(states_to_token_maps.start_states) > 0
|
||||
else None
|
||||
)
|
||||
self.tokenizer = tokenizer
|
||||
@ -264,7 +264,7 @@ class HeterogeneousNextTokenChooser:
|
||||
|
||||
self.grammar_processor = (
|
||||
HeterogeneousGrammarLogitProcessor(tokenizer, device, states_to_token_maps)
|
||||
if any(states_to_token_maps)
|
||||
if any([len(x.start_states) > 0 for x in states_to_token_maps])
|
||||
else None
|
||||
)
|
||||
|
||||
@ -434,10 +434,13 @@ class HeterogeneousNextTokenChooser:
|
||||
self.seeds = [self.seeds[i] for i in indices]
|
||||
self.do_sample = [self.do_sample[i] for i in indices]
|
||||
|
||||
new_states_to_token_maps = []
|
||||
new_fsm_grammar_states = []
|
||||
for i in indices:
|
||||
new_states_to_token_maps.append(self.states_to_token_maps[i])
|
||||
new_fsm_grammar_states.append(self.fsm_grammar_states[i])
|
||||
|
||||
self.states_to_token_maps = new_states_to_token_maps
|
||||
self.fsm_grammar_states = new_fsm_grammar_states
|
||||
|
||||
if any(self.do_sample):
|
||||
|
Loading…
Reference in New Issue
Block a user