fix: handle batches with and without grammars (#1676)

This PR correctly handles batches with a mixture of constrained and non
constrained generations.

Currently if batch contains mixed generations the generation will throw
an error because it will incorrectly attempt to constrain a request with
an empty grammar.

We now handled `None` grammars and only apply the mask if needed

Fixes:
https://github.com/huggingface/text-generation-inference/issues/1643
This commit is contained in:
drbh 2024-03-28 12:02:01 -04:00 committed by Karol Damaszke
parent d5ed4c110b
commit 56670398f3

View File

@ -526,6 +526,9 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
self.fsms = []
for grammar, grammar_type in zip(grammars, grammar_types):
if len(grammar) == 0:
self.fsms.append(None)
continue
fsm = GrammarLogitProcessor._cached_compile_fsm(
grammar_type, grammar, self.tokenizer
)
@ -543,7 +546,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
continue
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
mask[i, allowed_tokens] = 0
logits += mask
logits[i] += mask[i]
return logits
def advance_batch(self, next_token_ids, fsm_grammar_states):
@ -555,6 +558,8 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
]
def advance_at_index(self, next_token_id, fsm_grammar_state, index):
if self.fsms[index] is None:
return fsm_grammar_state
return GrammarLogitProcessor._advance(
next_token_id, fsm_grammar_state, self.fsms[index]
)