diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index cc9a39ac..613ec8b9 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -476,22 +476,22 @@ class GrammarLogitProcessor(LogitsProcessor):
     fsm_state: DefaultDict[int, int]
     fsm: RegexFSM
 
-    def __init__(self, tokenizer, device):
+    def __init__(self, tokenizer, device, grammar):
         self.device = device
         self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
+        self.fsm = GrammarLogitProcessor._cached_compile_fsm(
+            self, grammar, self.tokenizer
+        )
 
     def __call__(
         self,
-        _input_ids: torch.Tensor,
         logits: torch.Tensor,
         fsm_grammar_state: int,
-        grammar: str,
     ):
-        if fsm_grammar_state == -1 or grammar == "":
+        if fsm_grammar_state == -1 or self.fsm is None:
             return logits
-        fsm = GrammarLogitProcessor._cached_compile_fsm(self, grammar, self.tokenizer)
-        allowed_tokens = fsm.allowed_token_ids(fsm_grammar_state)
+        allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
         mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
         mask[allowed_tokens] = 0
         biased_scores = logits + mask
@@ -517,10 +517,11 @@ class GrammarLogitProcessor(LogitsProcessor):
         except json.JSONDecodeError:
             pass
         fsm = RegexFSM(schema, tokenizer)
-        logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
+        logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm
 
     @staticmethod
+    @lru_cache(maxsize=32, typed=True)
     def adapt_tokenizer(tokenizer):
         """Adapt tokenizer to work with the FSM.
 
@@ -529,6 +530,7 @@
         Llama's tokenizer to be able to compile FSMs for this model.
 
         """
+        start_time = time.time()
         tokenizer.vocabulary = tokenizer.get_vocab()
         tokenizer.special_tokens = set(tokenizer.all_special_tokens)
 
@@ -544,34 +546,40 @@ class GrammarLogitProcessor(LogitsProcessor):
             return string
 
         tokenizer.convert_token_to_string = convert_token_to_string
-
+        logger.debug(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
         return tokenizer
 
-    def filter(self, indices, fsm_grammar_states, grammars):
+    def filter(self, indices):
+        new_fsms = []
+        for i in indices:
+            new_fsms.append(self.fsms[i])
+        self.fsms = new_fsms
         return self
 
 
 class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
-    def __init__(self, tokenizer, device):
+    def __init__(self, tokenizer, device, grammars):
         self.device = device
         self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
+        self.fsms = [
+            (
+                GrammarLogitProcessor._cached_compile_fsm(self, g, self.tokenizer)
+                if g
+                else None
+            )
+            for g in grammars
+        ]
 
     def __call__(
         self,
-        _input_ids: torch.Tensor,
         logits: torch.Tensor,
         fsm_grammar_states: List[int],
-        grammars: List[str],
    ):
         for i in range(logits.shape[0]):
-            if fsm_grammar_states[i] == -1 or grammars[i] == "":
+            fsm = self.fsms[i]
+            if fsm_grammar_states[i] == -1 or fsm is None:
                 continue
-
-            fsm = GrammarLogitProcessor._cached_compile_fsm(
-                self, grammars[i], self.tokenizer
-            )
             allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
-
             mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
             mask[allowed_tokens] = 0
             biased_scores = logits[i] + mask
@@ -582,3 +590,6 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         return GrammarLogitProcessor.advance(
             self, next_token_ids, fsm_grammar_states, grammars
         )
+
+    def filter(self, indices):
+        return GrammarLogitProcessor.filter(self, indices)
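The net effect on the single-request processor: the FSM is compiled once at construction instead of on every decoding step, and `__call__` shrinks to `(logits, fsm_grammar_state)`. A minimal sketch of the resulting call pattern, for illustration only; the tokenizer checkpoint and JSON-schema grammar below are placeholder assumptions, not part of the patch:

```python
# Illustration only (not part of the patch): exercising the reworked API.
# The checkpoint name and grammar string are placeholder assumptions.
import torch
from transformers import AutoTokenizer
from text_generation_server.utils.logits_process import GrammarLogitProcessor

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
grammar = '{"type": "object", "properties": {"name": {"type": "string"}}}'

# The FSM is compiled (and cached) once, at construction time...
processor = GrammarLogitProcessor(tokenizer, device="cpu", grammar=grammar)

# ...so the per-step call no longer takes the grammar (or input_ids):
logits = torch.randn(len(tokenizer))
biased = processor(logits, fsm_grammar_state=0)
```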
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 6b1adce8..360e4fe3 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -52,7 +52,7 @@ class NextTokenChooser:
             else None
         )
         self.grammar_processor = (
-            GrammarLogitProcessor(tokenizer, device) if grammar != "" else None
+            GrammarLogitProcessor(tokenizer, device, grammar) if grammar != "" else None
         )
         self.tokenizer = tokenizer
 
@@ -83,9 +83,7 @@ class NextTokenChooser:
         if self.frequency_processor is not None:
             scores = self.frequency_processor(input_ids, scores)
         if self.grammar_processor is not None:
-            scores = self.grammar_processor(
-                input_ids, scores, self.fsm_grammar_state, self.grammar
-            )
+            scores = self.grammar_processor(scores, self.fsm_grammar_state)
 
         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
@@ -261,8 +259,8 @@ class HeterogeneousNextTokenChooser:
         )
 
         self.grammar_processor = (
-            HeterogeneousGrammarLogitProcessor(tokenizer, device)
-            if any([grammar != "" and grammar is not None for grammar in grammars])
+            HeterogeneousGrammarLogitProcessor(tokenizer, device, grammars)
+            if any([grammar != "" for grammar in grammars])
             else None
         )
 
@@ -331,9 +329,7 @@ class HeterogeneousNextTokenChooser:
             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)
             if self.grammar_processor is not None:
-                _scores = self.grammar_processor(
-                    input_ids, _scores, self.fsm_grammar_states, self.grammars
-                )
+                _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
             _next_ids = self.choice(_scores)
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids
@@ -408,6 +404,9 @@ class HeterogeneousNextTokenChooser:
         if self.frequency_processor is not None:
             self.frequency_processor = self.frequency_processor.filter(indices)
 
+        if self.grammar_processor is not None:
+            self.grammar_processor = self.grammar_processor.filter(indices)
+
         filtered_warpers = []
         for warper in self.warpers:
             filtered_warper = warper.filter(indices)
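On the batched side, `HeterogeneousGrammarLogitProcessor` now holds one pre-compiled FSM per request, and the new `filter` wiring keeps that list index-aligned with the batch as requests complete. A continuation of the sketch above, under the same placeholder assumptions:

```python
# Continuation of the sketch above; same placeholder tokenizer and grammar.
import torch
from transformers import AutoTokenizer
from text_generation_server.utils.logits_process import (
    HeterogeneousGrammarLogitProcessor,
)

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
grammar = '{"type": "object", "properties": {"name": {"type": "string"}}}'

# One FSM per request is compiled up front; "" maps to None and is skipped.
hetero = HeterogeneousGrammarLogitProcessor(
    tokenizer, device="cpu", grammars=[grammar, "", grammar]
)
assert hetero.fsms[1] is None

# When request 1 finishes, HeterogeneousNextTokenChooser.filter(indices)
# now also filters the grammar processor, so the FSM list stays
# index-aligned with the surviving batch slots:
hetero = hetero.filter([0, 2])
assert len(hetero.fsms) == 2
```

Without that `filter` hook, the FSM list would drift out of sync with the chooser's other per-request state once requests leave the batch.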