feat: improve grammar init

drbh 2024-02-12 18:42:09 +00:00
parent 13e07b8257
commit d0d7cd9e92
2 changed files with 37 additions and 27 deletions

View File

@@ -476,22 +476,22 @@ class GrammarLogitProcessor(LogitsProcessor):
    fsm_state: DefaultDict[int, int]
    fsm: RegexFSM

    def __init__(self, tokenizer, device):
    def __init__(self, tokenizer, device, grammar):
        self.device = device
        self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
        self.fsm = GrammarLogitProcessor._cached_compile_fsm(
            self, grammar, self.tokenizer
        )

    def __call__(
        self,
        _input_ids: torch.Tensor,
        logits: torch.Tensor,
        fsm_grammar_state: int,
        grammar: str,
    ):
        if fsm_grammar_state == -1 or grammar == "":
        if fsm_grammar_state == -1 or self.fsm is None:
            return logits
        fsm = GrammarLogitProcessor._cached_compile_fsm(self, grammar, self.tokenizer)
        allowed_tokens = fsm.allowed_token_ids(fsm_grammar_state)
        allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
        mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
        mask[allowed_tokens] = 0
        biased_scores = logits + mask
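
The __call__ hot path above is now just the masking step: the precompiled FSM reports which token ids are allowed in the current state, and every other logit is pushed to -inf. A small standalone illustration of that masking (toy vocabulary and allowed ids, not taken from the commit):

import math
import torch

logits = torch.tensor([1.0, 2.0, 3.0, 4.0])
allowed_tokens = [1, 3]  # ids the FSM permits in the current state

mask = torch.full((logits.shape[-1],), -math.inf)
mask[allowed_tokens] = 0
biased_scores = logits + mask
print(biased_scores)  # tensor([-inf, 2., -inf, 4.])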
@@ -517,10 +517,11 @@ class GrammarLogitProcessor(LogitsProcessor):
        except json.JSONDecodeError:
            pass
        fsm = RegexFSM(schema, tokenizer)
        logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
        logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
        return fsm

    @staticmethod
    @lru_cache(maxsize=32, typed=True)
    def adapt_tokenizer(tokenizer):
        """Adapt tokenizer to work with the FSM.
@@ -529,6 +530,7 @@ class GrammarLogitProcessor(LogitsProcessor):
        Llama's tokenizer to be able to compile FSMs for this model.
        """
        start_time = time.time()
        tokenizer.vocabulary = tokenizer.get_vocab()
        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
@@ -544,34 +546,40 @@ class GrammarLogitProcessor(LogitsProcessor):
            return string

        tokenizer.convert_token_to_string = convert_token_to_string
        logger.debug(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
        return tokenizer

    def filter(self, indices, fsm_grammar_states, grammars):
    def filter(self, indices):
        new_fsms = []
        for i in indices:
            new_fsms.append(self.fsms[i])
        self.fsms = new_fsms
        return self


class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
    def __init__(self, tokenizer, device):
    def __init__(self, tokenizer, device, grammars):
        self.device = device
        self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
        self.fsms = [
            (
                GrammarLogitProcessor._cached_compile_fsm(self, g, self.tokenizer)
                if g
                else None
            )
            for g in grammars
        ]

    def __call__(
        self,
        _input_ids: torch.Tensor,
        logits: torch.Tensor,
        fsm_grammar_states: List[int],
        grammars: List[str],
    ):
        for i in range(logits.shape[0]):
            if fsm_grammar_states[i] == -1 or grammars[i] == "":
            fsm = self.fsms[i]
            if fsm_grammar_states[i] == -1 or fsm is None:
                continue
            fsm = GrammarLogitProcessor._cached_compile_fsm(
                self, grammars[i], self.tokenizer
            )
            allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
            mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
            mask[allowed_tokens] = 0
            biased_scores = logits[i] + mask
@@ -582,3 +590,6 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
        return GrammarLogitProcessor.advance(
            self, next_token_ids, fsm_grammar_states, grammars
        )

    def filter(self, indices):
        return GrammarLogitProcessor.filter(self, indices)
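
HeterogeneousGrammarLogitProcessor now builds its FSM list once, at construction: one slot per request, with None for requests that have no grammar, so __call__ only has to skip the None entries instead of compiling per step. A minimal sketch of that pattern, with a toy compile function standing in for _cached_compile_fsm / RegexFSM:

from functools import lru_cache


@lru_cache(maxsize=32)
def compile_fsm(grammar: str) -> str:
    # stand-in for an expensive FSM build
    return f"<fsm:{grammar}>"


grammars = ['{"type": "string"}', "", '{"type": "integer"}']
# one slot per request; requests without a grammar get None
fsms = [compile_fsm(g) if g else None for g in grammars]
assert fsms[1] is None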

View File

@@ -52,7 +52,7 @@ class NextTokenChooser:
            else None
        )
        self.grammar_processor = (
            GrammarLogitProcessor(tokenizer, device) if grammar != "" else None
            GrammarLogitProcessor(tokenizer, device, grammar) if grammar != "" else None
        )
        self.tokenizer = tokenizer
@@ -83,9 +83,7 @@ class NextTokenChooser:
        if self.frequency_processor is not None:
            scores = self.frequency_processor(input_ids, scores)
        if self.grammar_processor is not None:
            scores = self.grammar_processor(
                input_ids, scores, self.fsm_grammar_state, self.grammar
            )
            scores = self.grammar_processor(scores, self.fsm_grammar_state)

        if self.static_warper is None:
            next_logprob = torch.log_softmax(scores, -1)
@@ -261,8 +259,8 @@ class HeterogeneousNextTokenChooser:
        )
        self.grammar_processor = (
            HeterogeneousGrammarLogitProcessor(tokenizer, device)
            if any([grammar != "" and grammar is not None for grammar in grammars])
            HeterogeneousGrammarLogitProcessor(tokenizer, device, grammars)
            if any([grammar != "" for grammar in grammars])
            else None
        )
@@ -331,9 +329,7 @@ class HeterogeneousNextTokenChooser:
            for warper in self.warpers:
                _scores = warper(input_ids, _scores)
            if self.grammar_processor is not None:
                _scores = self.grammar_processor(
                    input_ids, _scores, self.fsm_grammar_states, self.grammars
                )
                _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
            _next_ids = self.choice(_scores)
            scores[:, j] = _scores
            next_ids[:, j] = _next_ids
@@ -408,6 +404,9 @@ class HeterogeneousNextTokenChooser:
        if self.frequency_processor is not None:
            self.frequency_processor = self.frequency_processor.filter(indices)
        if self.grammar_processor is not None:
            self.grammar_processor = self.grammar_processor.filter(indices)

        filtered_warpers = []
        for warper in self.warpers:
            filtered_warper = warper.filter(indices)
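
With the FSMs now held on the processors, batch filtering has to keep them aligned with the surviving requests, which is why both choosers forward filter(indices) to the grammar processor. A small standalone sketch of that contract (toy class with assumed semantics, not the actual TGI types):

class FsmHolder:
    # toy stand-in for the grammar processors' filter() behaviour
    def __init__(self, fsms):
        self.fsms = fsms

    def filter(self, indices):
        # keep only the FSMs whose requests are still in the batch
        self.fsms = [self.fsms[i] for i in indices]
        return self


holder = FsmHolder(["fsm_a", None, "fsm_c"])  # None = request without grammar
holder = holder.filter([0, 2])                # requests 0 and 2 remain
assert holder.fsms == ["fsm_a", "fsm_c"]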