mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-05-08 10:52:14 +00:00
fix: check start_states len and add states_to_token_maps to filter
This commit is contained in:
parent
5b5cbb14d6
commit
65b5b4c36e
@ -54,7 +54,7 @@ class NextTokenChooser:
|
|||||||
)
|
)
|
||||||
self.grammar_processor = (
|
self.grammar_processor = (
|
||||||
GrammarLogitProcessor(tokenizer, device, states_to_token_maps)
|
GrammarLogitProcessor(tokenizer, device, states_to_token_maps)
|
||||||
if states_to_token_maps
|
if len(states_to_token_maps.start_states) > 0
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
@ -264,7 +264,7 @@ class HeterogeneousNextTokenChooser:
|
|||||||
|
|
||||||
self.grammar_processor = (
|
self.grammar_processor = (
|
||||||
HeterogeneousGrammarLogitProcessor(tokenizer, device, states_to_token_maps)
|
HeterogeneousGrammarLogitProcessor(tokenizer, device, states_to_token_maps)
|
||||||
if any(states_to_token_maps)
|
if any([len(x.start_states) > 0 for x in states_to_token_maps])
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -434,10 +434,13 @@ class HeterogeneousNextTokenChooser:
|
|||||||
self.seeds = [self.seeds[i] for i in indices]
|
self.seeds = [self.seeds[i] for i in indices]
|
||||||
self.do_sample = [self.do_sample[i] for i in indices]
|
self.do_sample = [self.do_sample[i] for i in indices]
|
||||||
|
|
||||||
|
new_states_to_token_maps = []
|
||||||
new_fsm_grammar_states = []
|
new_fsm_grammar_states = []
|
||||||
for i in indices:
|
for i in indices:
|
||||||
|
new_states_to_token_maps.append(self.states_to_token_maps[i])
|
||||||
new_fsm_grammar_states.append(self.fsm_grammar_states[i])
|
new_fsm_grammar_states.append(self.fsm_grammar_states[i])
|
||||||
|
|
||||||
|
self.states_to_token_maps = new_states_to_token_maps
|
||||||
self.fsm_grammar_states = new_fsm_grammar_states
|
self.fsm_grammar_states = new_fsm_grammar_states
|
||||||
|
|
||||||
if any(self.do_sample):
|
if any(self.do_sample):
|
||||||
|
Loading…
Reference in New Issue
Block a user