mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: improve grammar init
This commit is contained in:
parent
13e07b8257
commit
d0d7cd9e92
@ -476,22 +476,22 @@ class GrammarLogitProcessor(LogitsProcessor):
|
|||||||
fsm_state: DefaultDict[int, int]
|
fsm_state: DefaultDict[int, int]
|
||||||
fsm: RegexFSM
|
fsm: RegexFSM
|
||||||
|
|
||||||
def __init__(self, tokenizer, device):
|
def __init__(self, tokenizer, device, grammar):
|
||||||
self.device = device
|
self.device = device
|
||||||
self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
|
self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
|
||||||
|
self.fsm = GrammarLogitProcessor._cached_compile_fsm(
|
||||||
|
self, grammar, self.tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
_input_ids: torch.Tensor,
|
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
fsm_grammar_state: int,
|
fsm_grammar_state: int,
|
||||||
grammar: str,
|
|
||||||
):
|
):
|
||||||
if fsm_grammar_state == -1 or grammar == "":
|
if fsm_grammar_state == -1 or self.fsm is None:
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
fsm = GrammarLogitProcessor._cached_compile_fsm(self, grammar, self.tokenizer)
|
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
|
||||||
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_state)
|
|
||||||
mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
|
mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
|
||||||
mask[allowed_tokens] = 0
|
mask[allowed_tokens] = 0
|
||||||
biased_scores = logits + mask
|
biased_scores = logits + mask
|
||||||
@ -517,10 +517,11 @@ class GrammarLogitProcessor(LogitsProcessor):
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
pass
|
pass
|
||||||
fsm = RegexFSM(schema, tokenizer)
|
fsm = RegexFSM(schema, tokenizer)
|
||||||
logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
|
logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
|
||||||
return fsm
|
return fsm
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@lru_cache(maxsize=32, typed=True)
|
||||||
def adapt_tokenizer(tokenizer):
|
def adapt_tokenizer(tokenizer):
|
||||||
"""Adapt tokenizer to work with the FSM.
|
"""Adapt tokenizer to work with the FSM.
|
||||||
|
|
||||||
@ -529,6 +530,7 @@ class GrammarLogitProcessor(LogitsProcessor):
|
|||||||
Llama's tokenizer to be able to compile FSMs for this model.
|
Llama's tokenizer to be able to compile FSMs for this model.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
start_time = time.time()
|
||||||
tokenizer.vocabulary = tokenizer.get_vocab()
|
tokenizer.vocabulary = tokenizer.get_vocab()
|
||||||
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
|
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
|
||||||
|
|
||||||
@ -544,34 +546,40 @@ class GrammarLogitProcessor(LogitsProcessor):
|
|||||||
return string
|
return string
|
||||||
|
|
||||||
tokenizer.convert_token_to_string = convert_token_to_string
|
tokenizer.convert_token_to_string = convert_token_to_string
|
||||||
|
logger.debug(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
def filter(self, indices, fsm_grammar_states, grammars):
|
def filter(self, indices):
|
||||||
|
new_fsms = []
|
||||||
|
for i in indices:
|
||||||
|
new_fsms.append(self.fsms[i])
|
||||||
|
self.fsms = new_fsms
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
||||||
def __init__(self, tokenizer, device):
|
def __init__(self, tokenizer, device, grammars):
|
||||||
self.device = device
|
self.device = device
|
||||||
self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
|
self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
|
||||||
|
self.fsms = [
|
||||||
|
(
|
||||||
|
GrammarLogitProcessor._cached_compile_fsm(self, g, self.tokenizer)
|
||||||
|
if g
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
for g in grammars
|
||||||
|
]
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
_input_ids: torch.Tensor,
|
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
fsm_grammar_states: List[int],
|
fsm_grammar_states: List[int],
|
||||||
grammars: List[str],
|
|
||||||
):
|
):
|
||||||
for i in range(logits.shape[0]):
|
for i in range(logits.shape[0]):
|
||||||
if fsm_grammar_states[i] == -1 or grammars[i] == "":
|
fsm = self.fsms[i]
|
||||||
|
if fsm_grammar_states[i] == -1 or fsm is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
fsm = GrammarLogitProcessor._cached_compile_fsm(
|
|
||||||
self, grammars[i], self.tokenizer
|
|
||||||
)
|
|
||||||
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
|
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
|
||||||
|
|
||||||
mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
|
mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
|
||||||
mask[allowed_tokens] = 0
|
mask[allowed_tokens] = 0
|
||||||
biased_scores = logits[i] + mask
|
biased_scores = logits[i] + mask
|
||||||
@ -582,3 +590,6 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
|||||||
return GrammarLogitProcessor.advance(
|
return GrammarLogitProcessor.advance(
|
||||||
self, next_token_ids, fsm_grammar_states, grammars
|
self, next_token_ids, fsm_grammar_states, grammars
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def filter(self, indices):
|
||||||
|
return GrammarLogitProcessor.filter(self, indices)
|
||||||
|
@ -52,7 +52,7 @@ class NextTokenChooser:
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
self.grammar_processor = (
|
self.grammar_processor = (
|
||||||
GrammarLogitProcessor(tokenizer, device) if grammar != "" else None
|
GrammarLogitProcessor(tokenizer, device, grammar) if grammar != "" else None
|
||||||
)
|
)
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
@ -83,9 +83,7 @@ class NextTokenChooser:
|
|||||||
if self.frequency_processor is not None:
|
if self.frequency_processor is not None:
|
||||||
scores = self.frequency_processor(input_ids, scores)
|
scores = self.frequency_processor(input_ids, scores)
|
||||||
if self.grammar_processor is not None:
|
if self.grammar_processor is not None:
|
||||||
scores = self.grammar_processor(
|
scores = self.grammar_processor(scores, self.fsm_grammar_state)
|
||||||
input_ids, scores, self.fsm_grammar_state, self.grammar
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.static_warper is None:
|
if self.static_warper is None:
|
||||||
next_logprob = torch.log_softmax(scores, -1)
|
next_logprob = torch.log_softmax(scores, -1)
|
||||||
@ -261,8 +259,8 @@ class HeterogeneousNextTokenChooser:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.grammar_processor = (
|
self.grammar_processor = (
|
||||||
HeterogeneousGrammarLogitProcessor(tokenizer, device)
|
HeterogeneousGrammarLogitProcessor(tokenizer, device, grammars)
|
||||||
if any([grammar != "" and grammar is not None for grammar in grammars])
|
if any([grammar != "" for grammar in grammars])
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -331,9 +329,7 @@ class HeterogeneousNextTokenChooser:
|
|||||||
for warper in self.warpers:
|
for warper in self.warpers:
|
||||||
_scores = warper(input_ids, _scores)
|
_scores = warper(input_ids, _scores)
|
||||||
if self.grammar_processor is not None:
|
if self.grammar_processor is not None:
|
||||||
_scores = self.grammar_processor(
|
_scores = self.grammar_processor(_scores, self.fsm_grammar_states)
|
||||||
input_ids, _scores, self.fsm_grammar_states, self.grammars
|
|
||||||
)
|
|
||||||
_next_ids = self.choice(_scores)
|
_next_ids = self.choice(_scores)
|
||||||
scores[:, j] = _scores
|
scores[:, j] = _scores
|
||||||
next_ids[:, j] = _next_ids
|
next_ids[:, j] = _next_ids
|
||||||
@ -408,6 +404,9 @@ class HeterogeneousNextTokenChooser:
|
|||||||
if self.frequency_processor is not None:
|
if self.frequency_processor is not None:
|
||||||
self.frequency_processor = self.frequency_processor.filter(indices)
|
self.frequency_processor = self.frequency_processor.filter(indices)
|
||||||
|
|
||||||
|
if self.grammar_processor is not None:
|
||||||
|
self.grammar_processor = self.grammar_processor.filter(indices)
|
||||||
|
|
||||||
filtered_warpers = []
|
filtered_warpers = []
|
||||||
for warper in self.warpers:
|
for warper in self.warpers:
|
||||||
filtered_warper = warper.filter(indices)
|
filtered_warper = warper.filter(indices)
|
||||||
|
Loading…
Reference in New Issue
Block a user