Check if allowed tokens is None

This commit is contained in:
Alex Weston 2024-10-25 11:21:17 -04:00
parent 803d697d3d
commit 3f79225472

View File

@ -498,9 +498,10 @@ class GrammarLogitProcessor(LogitsProcessor):
): ):
if fsm_grammar_state == -1 or self.fsm is None: if fsm_grammar_state == -1 or self.fsm is None:
return logits return logits
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state) allowed_tokens = self.fsm.get_next_instruction(fsm_grammar_state).tokens
mask = torch.full_like(logits, -math.inf) mask = torch.full_like(logits, -math.inf)
mask[:, allowed_tokens] = 0 if allowed_tokens is not None:
mask[:, allowed_tokens] = 0
biased_scores = logits + mask biased_scores = logits + mask
return biased_scores return biased_scores
@ -589,7 +590,8 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
if fsm_grammar_states[i] == -1 or fsm is None: if fsm_grammar_states[i] == -1 or fsm is None:
continue continue
allowed_tokens = fsm.get_next_instruction(fsm_grammar_states[i]).tokens allowed_tokens = fsm.get_next_instruction(fsm_grammar_states[i]).tokens
mask[i, allowed_tokens] = 0 if allowed_tokens is not None:
mask[i, allowed_tokens] = 0
logits[i] += mask[i] logits[i] += mask[i]
return logits return logits