mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Check if allowed tokens is None
This commit is contained in:
parent
803d697d3d
commit
3f79225472
@ -498,9 +498,10 @@ class GrammarLogitProcessor(LogitsProcessor):
|
|||||||
):
|
):
|
||||||
if fsm_grammar_state == -1 or self.fsm is None:
|
if fsm_grammar_state == -1 or self.fsm is None:
|
||||||
return logits
|
return logits
|
||||||
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
|
allowed_tokens = self.fsm.get_next_instruction(fsm_grammar_state).tokens
|
||||||
mask = torch.full_like(logits, -math.inf)
|
mask = torch.full_like(logits, -math.inf)
|
||||||
mask[:, allowed_tokens] = 0
|
if allowed_tokens is not None:
|
||||||
|
mask[:, allowed_tokens] = 0
|
||||||
biased_scores = logits + mask
|
biased_scores = logits + mask
|
||||||
return biased_scores
|
return biased_scores
|
||||||
|
|
||||||
@ -589,7 +590,8 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
|||||||
if fsm_grammar_states[i] == -1 or fsm is None:
|
if fsm_grammar_states[i] == -1 or fsm is None:
|
||||||
continue
|
continue
|
||||||
allowed_tokens = fsm.get_next_instruction(fsm_grammar_states[i]).tokens
|
allowed_tokens = fsm.get_next_instruction(fsm_grammar_states[i]).tokens
|
||||||
mask[i, allowed_tokens] = 0
|
if allowed_tokens is not None:
|
||||||
|
mask[i, allowed_tokens] = 0
|
||||||
logits[i] += mask[i]
|
logits[i] += mask[i]
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user