mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
feat: improve grammar init
This commit is contained in:
parent
13e07b8257
commit
d0d7cd9e92
@ -476,22 +476,22 @@ class GrammarLogitProcessor(LogitsProcessor):
|
||||
fsm_state: DefaultDict[int, int]
|
||||
fsm: RegexFSM
|
||||
|
||||
def __init__(self, tokenizer, device):
|
||||
def __init__(self, tokenizer, device, grammar):
|
||||
self.device = device
|
||||
self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
|
||||
self.fsm = GrammarLogitProcessor._cached_compile_fsm(
|
||||
self, grammar, self.tokenizer
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
_input_ids: torch.Tensor,
|
||||
logits: torch.Tensor,
|
||||
fsm_grammar_state: int,
|
||||
grammar: str,
|
||||
):
|
||||
if fsm_grammar_state == -1 or grammar == "":
|
||||
if fsm_grammar_state == -1 or self.fsm is None:
|
||||
return logits
|
||||
|
||||
fsm = GrammarLogitProcessor._cached_compile_fsm(self, grammar, self.tokenizer)
|
||||
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_state)
|
||||
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
|
||||
mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
|
||||
mask[allowed_tokens] = 0
|
||||
biased_scores = logits + mask
|
||||
@ -517,10 +517,11 @@ class GrammarLogitProcessor(LogitsProcessor):
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
fsm = RegexFSM(schema, tokenizer)
|
||||
logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
|
||||
logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
|
||||
return fsm
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=32, typed=True)
|
||||
def adapt_tokenizer(tokenizer):
|
||||
"""Adapt tokenizer to work with the FSM.
|
||||
|
||||
@ -529,6 +530,7 @@ class GrammarLogitProcessor(LogitsProcessor):
|
||||
Llama's tokenizer to be able to compile FSMs for this model.
|
||||
|
||||
"""
|
||||
start_time = time.time()
|
||||
tokenizer.vocabulary = tokenizer.get_vocab()
|
||||
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
|
||||
|
||||
@ -544,34 +546,40 @@ class GrammarLogitProcessor(LogitsProcessor):
|
||||
return string
|
||||
|
||||
tokenizer.convert_token_to_string = convert_token_to_string
|
||||
|
||||
logger.debug(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
|
||||
return tokenizer
|
||||
|
||||
def filter(self, indices, fsm_grammar_states, grammars):
|
||||
def filter(self, indices):
|
||||
new_fsms = []
|
||||
for i in indices:
|
||||
new_fsms.append(self.fsms[i])
|
||||
self.fsms = new_fsms
|
||||
return self
|
||||
|
||||
|
||||
class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
||||
def __init__(self, tokenizer, device):
|
||||
def __init__(self, tokenizer, device, grammars):
|
||||
self.device = device
|
||||
self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
|
||||
self.fsms = [
|
||||
(
|
||||
GrammarLogitProcessor._cached_compile_fsm(self, g, self.tokenizer)
|
||||
if g
|
||||
else None
|
||||
)
|
||||
for g in grammars
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
_input_ids: torch.Tensor,
|
||||
logits: torch.Tensor,
|
||||
fsm_grammar_states: List[int],
|
||||
grammars: List[str],
|
||||
):
|
||||
for i in range(logits.shape[0]):
|
||||
if fsm_grammar_states[i] == -1 or grammars[i] == "":
|
||||
fsm = self.fsms[i]
|
||||
if fsm_grammar_states[i] == -1 or fsm is None:
|
||||
continue
|
||||
|
||||
fsm = GrammarLogitProcessor._cached_compile_fsm(
|
||||
self, grammars[i], self.tokenizer
|
||||
)
|
||||
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
|
||||
|
||||
mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
|
||||
mask[allowed_tokens] = 0
|
||||
biased_scores = logits[i] + mask
|
||||
@ -582,3 +590,6 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
||||
return GrammarLogitProcessor.advance(
|
||||
self, next_token_ids, fsm_grammar_states, grammars
|
||||
)
|
||||
|
||||
def filter(self, indices):
|
||||
return GrammarLogitProcessor.filter(self, indices)
|
||||
|
@ -52,7 +52,7 @@ class NextTokenChooser:
|
||||
else None
|
||||
)
|
||||
self.grammar_processor = (
|
||||
GrammarLogitProcessor(tokenizer, device) if grammar != "" else None
|
||||
GrammarLogitProcessor(tokenizer, device, grammar) if grammar != "" else None
|
||||
)
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@ -83,9 +83,7 @@ class NextTokenChooser:
|
||||
if self.frequency_processor is not None:
|
||||
scores = self.frequency_processor(input_ids, scores)
|
||||
if self.grammar_processor is not None:
|
||||
scores = self.grammar_processor(
|
||||
input_ids, scores, self.fsm_grammar_state, self.grammar
|
||||
)
|
||||
scores = self.grammar_processor(scores, self.fsm_grammar_state)
|
||||
|
||||
if self.static_warper is None:
|
||||
next_logprob = torch.log_softmax(scores, -1)
|
||||
@ -261,8 +259,8 @@ class HeterogeneousNextTokenChooser:
|
||||
)
|
||||
|
||||
self.grammar_processor = (
|
||||
HeterogeneousGrammarLogitProcessor(tokenizer, device)
|
||||
if any([grammar != "" and grammar is not None for grammar in grammars])
|
||||
HeterogeneousGrammarLogitProcessor(tokenizer, device, grammars)
|
||||
if any([grammar != "" for grammar in grammars])
|
||||
else None
|
||||
)
|
||||
|
||||
@ -331,9 +329,7 @@ class HeterogeneousNextTokenChooser:
|
||||
for warper in self.warpers:
|
||||
_scores = warper(input_ids, _scores)
|
||||
if self.grammar_processor is not None:
|
||||
_scores = self.grammar_processor(
|
||||
input_ids, _scores, self.fsm_grammar_states, self.grammars
|
||||
)
|
||||
_scores = self.grammar_processor(_scores, self.fsm_grammar_states)
|
||||
_next_ids = self.choice(_scores)
|
||||
scores[:, j] = _scores
|
||||
next_ids[:, j] = _next_ids
|
||||
@ -408,6 +404,9 @@ class HeterogeneousNextTokenChooser:
|
||||
if self.frequency_processor is not None:
|
||||
self.frequency_processor = self.frequency_processor.filter(indices)
|
||||
|
||||
if self.grammar_processor is not None:
|
||||
self.grammar_processor = self.grammar_processor.filter(indices)
|
||||
|
||||
filtered_warpers = []
|
||||
for warper in self.warpers:
|
||||
filtered_warper = warper.filter(indices)
|
||||
|
Loading…
Reference in New Issue
Block a user