mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Check if allowed tokens is None
This commit is contained in:
parent
803d697d3d
commit
3f79225472
@ -498,9 +498,10 @@ class GrammarLogitProcessor(LogitsProcessor):
|
||||
):
|
||||
if fsm_grammar_state == -1 or self.fsm is None:
|
||||
return logits
|
||||
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
|
||||
allowed_tokens = self.fsm.get_next_instruction(fsm_grammar_state).tokens
|
||||
mask = torch.full_like(logits, -math.inf)
|
||||
mask[:, allowed_tokens] = 0
|
||||
if allowed_tokens is not None:
|
||||
mask[:, allowed_tokens] = 0
|
||||
biased_scores = logits + mask
|
||||
return biased_scores
|
||||
|
||||
@ -589,7 +590,8 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
||||
if fsm_grammar_states[i] == -1 or fsm is None:
|
||||
continue
|
||||
allowed_tokens = fsm.get_next_instruction(fsm_grammar_states[i]).tokens
|
||||
mask[i, allowed_tokens] = 0
|
||||
if allowed_tokens is not None:
|
||||
mask[i, allowed_tokens] = 0
|
||||
logits[i] += mask[i]
|
||||
return logits
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user