fix: handle batches with and without grammars

This commit is contained in:
drbh 2024-03-25 23:18:50 +00:00
parent 6c4496a1a3
commit 0cd04fe4f7

View File

@ -555,6 +555,9 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
self.fsms = [] self.fsms = []
for grammar, grammar_type in zip(grammars, grammar_types): for grammar, grammar_type in zip(grammars, grammar_types):
if len(grammar) == 0:
self.fsms.append(None)
continue
fsm = GrammarLogitProcessor._cached_compile_fsm( fsm = GrammarLogitProcessor._cached_compile_fsm(
grammar_type, grammar, self.tokenizer grammar_type, grammar, self.tokenizer
) )
@ -572,7 +575,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
continue continue
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i]) allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
mask[i, allowed_tokens] = 0 mask[i, allowed_tokens] = 0
logits += mask logits[i] += mask[i]
return logits return logits
def advance_batch(self, next_token_ids, fsm_grammar_states): def advance_batch(self, next_token_ids, fsm_grammar_states):
@ -584,6 +587,8 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
] ]
def advance_at_index(self, next_token_id, fsm_grammar_state, index): def advance_at_index(self, next_token_id, fsm_grammar_state, index):
if self.fsms[index] is None:
return fsm_grammar_state
return GrammarLogitProcessor._advance( return GrammarLogitProcessor._advance(
next_token_id, fsm_grammar_state, self.fsms[index] next_token_id, fsm_grammar_state, self.fsms[index]
) )