From f1d43f2df45e9cda4eb6569b462d7b791b4ddf16 Mon Sep 17 00:00:00 2001
From: drbh
Date: Sat, 10 Feb 2024 01:41:22 +0000
Subject: [PATCH] fix: prefer grammar as logit processor

---
 .../models/flash_causal_lm.py                 |   5 -
 .../utils/logits_process.py                   | 153 ++++++++++++++-
 server/text_generation_server/utils/tokens.py | 177 ++++++------------
 3 files changed, 208 insertions(+), 127 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index c5ce89be..bd91b0e2 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -99,9 +99,6 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of blocks
     max_blocks: int
 
-    # The states for the grammar FSM
-    fsm_states: Dict[int, int] = None
-
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -140,7 +137,6 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
         all_input_ids = []
         requests_idx_mapping = {}
-        fsm_states = {}
 
         all_prefill_logprobs = True
         no_prefill_logprobs = True
@@ -323,7 +319,6 @@
             blocks=blocks,
             max_blocks=max_blocks,
             speculative_ids=None,
-            fsm_states=fsm_states,
         )
 
     @tracer.start_as_current_span("filter")
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index 291c522f..3116064c 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -4,6 +4,11 @@
 import torch
 from functools import lru_cache
 from typing import Optional, List, Dict, Union
+from outlines.fsm.fsm import RegexFSM
+from outlines.fsm.json_schema import build_regex_from_object
+from typing import DefaultDict
+import time
+
 from transformers import (
     LogitsWarper,
     LogitsProcessor,
@@ -135,9 +140,7 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
     ) -> torch.FloatTensor:
         score = torch.gather(scores, 1, input_ids)
         # if score < 0 then penalty has to be multiplied to reduce the previous token probability
-        score = -torch.where(
-            score < 0, score * self.penalty, score / self.penalty
-        )
+        score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
 
         return scores.scatter_add_(1, input_ids, score)
 
@@ -464,3 +467,147 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
                 self.processors = new_processors
                 return self
         return None
+
+
+class GrammarLogitProcessor(LogitsProcessor):
+    fsm_state: DefaultDict[int, int]
+    fsm: RegexFSM
+
+    def __init__(self, tokenizer, device):
+        self.device = device
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_state, grammar):
+        if fsm_grammar_state == -1:
+            # TODO: mask everything except the EOS token
+            return logits
+
+        # if grammar is '' or None, leave the logits unchanged
+        if grammar == "" or grammar is None:
+            return logits
+
+        fsm = self.compile_fsm(grammar, self.tokenizer)
+        allowed_tokens = fsm.allowed_token_ids(fsm_grammar_state)
+        mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
+        mask[allowed_tokens] = 0
+        biased_scores = logits + mask
+        return biased_scores
+
+    def advance(self, next_token_id, fsm_grammar_state, grammar):
+        if fsm_grammar_state == -1:
+            return fsm_grammar_state
+
+        if grammar == "" or grammar is None:
+            return fsm_grammar_state
+
+        fsm = self.compile_fsm(grammar, self.tokenizer)
+        return fsm.next_state(fsm_grammar_state, next_token_id)
+
+    @lru_cache(maxsize=32, typed=True)
+    def compile_fsm(self, schema, tokenizer):
+        start_time = time.time()
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        is_json_string = schema.startswith("{") and schema.endswith("}")
+        regex_string = build_regex_from_object(schema) if is_json_string else schema
+        fsm = RegexFSM(regex_string, tokenizer)
+        # print(f"Compile FSM: {time.time() - start_time}")
+        return fsm
+
+    def adapt_tokenizer(self, tokenizer):
+        """Adapt tokenizer to work with the FSM.
+
+        The API of Outlines tokenizers is slightly different from that of
+        `transformers`. In addition, we need to handle the missing spaces in
+        Llama's tokenizer to be able to compile FSMs for this model.
+
+        """
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces in HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+
+        return tokenizer
+
+    def filter(self, indices, fsm_grammar_states, grammars):
+        return self
+
+
+class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
+    def __init__(self, tokenizer, device):
+        self.device = device
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_states, grammars):
+        for i in range(len(logits)):
+            if fsm_grammar_states[i] == -1:
+                # TODO: mask everything except the EOS token
+                continue
+
+            # if grammar is '' or None, leave this row's logits unchanged
+            if grammars[i] == "" or grammars[i] is None:
+                continue
+
+            fsm = self.compile_fsm(grammars[i], self.tokenizer)
+            allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
+            mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
+            mask[allowed_tokens] = 0
+            biased_scores = logits[i] + mask
+            logits[i] = biased_scores
+        return logits
+
+    def advance(self, next_token_id, fsm_grammar_state, grammar):
+        if fsm_grammar_state == -1:
+            return fsm_grammar_state
+
+        if grammar == "" or grammar is None:
+            return fsm_grammar_state
+
+        fsm = self.compile_fsm(grammar, self.tokenizer)
+        return fsm.next_state(fsm_grammar_state, next_token_id)
+
+    @lru_cache(maxsize=32, typed=True)
+    def compile_fsm(self, schema, tokenizer):
+        start_time = time.time()
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        is_json_string = schema.startswith("{") and schema.endswith("}")
+        regex_string = build_regex_from_object(schema) if is_json_string else schema
+        fsm = RegexFSM(regex_string, tokenizer)
+        # print(f"Compile FSM: {time.time() - start_time}")
+        return fsm
+
+    def adapt_tokenizer(self, tokenizer):
+        """Adapt tokenizer to work with the FSM.
+
+        The API of Outlines tokenizers is slightly different from that of
+        `transformers`. In addition, we need to handle the missing spaces in
+        Llama's tokenizer to be able to compile FSMs for this model.
+
+        """
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces in HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+
+        return tokenizer
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 67fdcec0..0647b765 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -8,6 +8,7 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
 from text_generation_server.utils.logits_process import (
     FrequencyPenaltyLogitsProcessor,
+    GrammarLogitProcessor,
     HeterogeneousProcessorWrapper,
     HeterogeneousRepetitionPenaltyLogitsProcessor,
     HeterogeneousFrequencyPenaltyLogitsProcessor,
@@ -15,6 +16,7 @@ from text_generation_server.utils.logits_process import (
     HeterogeneousTopKLogitsWarper,
     HeterogeneousTopPLogitsWarper,
     HeterogeneousTypicalLogitsWarper,
+    HeterogeneousGrammarLogitProcessor,
     static_warper,
 )
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
@@ -58,6 +60,11 @@ class NextTokenChooser:
             if frequency_penalty and frequency_penalty != 0.0
             else None
         )
+        self.grammar_processor = (
+            GrammarLogitProcessor(tokenizer, device)
+            if grammar and grammar != ""
+            else None
+        )
         self.tokenizer = tokenizer
 
         has_warpers = (
@@ -75,14 +82,9 @@ class NextTokenChooser:
 
         sampling = do_sample or has_warpers
 
+        self.choice = Sampling(seed, device) if sampling else Greedy()
         self.fsm_grammar_state = fsm_grammar_state
-        self.grammars = grammar
-
-        # TODO: is grammar a subset of sampling? If so, we should merge them
-        if grammar:
-            self.choice = Grammar(tokenizer, device, grammar)
-        else:
-            self.choice = Sampling(seed, device) if sampling else Greedy()
+        self.grammar = grammar
 
     def __call__(self, input_ids, scores):
         if self.watermark_processor is not None:
@@ -91,15 +93,23 @@ class NextTokenChooser:
             scores = self.repetition_processor(input_ids, scores)
         if self.frequency_processor is not None:
             scores = self.frequency_processor(input_ids, scores)
+        if self.grammar_processor is not None:
+            scores = self.grammar_processor(
+                input_ids, scores, self.fsm_grammar_state, self.grammar
+            )
 
         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
         else:
             scores, next_logprob = self.static_warper(scores)
 
-        next_id = self.choice(scores[-1], self.fsm_grammar_state, self.grammars).view(
-            1, 1
-        )
+        next_id = self.choice(scores[-1]).view(1, 1)
+
+        if self.grammar_processor is not None:
+            next_state = self.grammar_processor.advance(
+                next_id.item(), self.fsm_grammar_state, self.grammar
+            )
+            self.fsm_grammar_state = next_state
 
         return next_id, next_logprob
 
@@ -229,7 +239,7 @@ class HeterogeneousNextTokenChooser:
         do_sample: List[bool],
         seeds: List[int],
         tokenizer=None,
-        grammar=None,
+        grammars=None,
         fsm_grammar_states=None,
     ):
         warpers = []
@@ -262,6 +272,12 @@ class HeterogeneousNextTokenChooser:
             else None
         )
 
+        self.grammar_processor = (
+            HeterogeneousGrammarLogitProcessor(tokenizer, device)
+            if grammars is not None and any(grammars)
+            else None
+        )
+
         if any([x != 1.0 for x in temperature]):
             do_sample = [
                 sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@@ -284,21 +300,18 @@ class HeterogeneousNextTokenChooser:
 
         self.warpers = warpers
 
-        if grammar is not None:
-            self.choice = Grammar(tokenizer, device)
-        elif any(do_sample):
+        if any(do_sample):
             self.choice = HeterogeneousSampling(do_sample, seeds, device)
         else:
             self.choice = Greedy()
 
-        self.use_grammar = grammar is not None
         self.seeds = seeds
         self.do_sample = do_sample
         self.dtype = dtype
         self.device = device
         self.tokenizer = tokenizer
         self.fsm_grammar_states = fsm_grammar_states
-        self.grammars = grammar
+        self.grammars = grammars
 
     def __call__(
         self,
@@ -327,11 +340,13 @@ class HeterogeneousNextTokenChooser:
                 _scores = self.repetition_processor(input_ids, _scores)
             if self.frequency_processor is not None:
                 _scores = self.frequency_processor(input_ids, _scores)
-
             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)
-
-            _next_ids = self.choice(_scores, self.fsm_grammar_states, self.grammars)
+            if self.grammar_processor is not None:
+                _scores = self.grammar_processor(
+                    input_ids, _scores, self.fsm_grammar_states, self.grammars
+                )
+            _next_ids = self.choice(_scores)
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids
         next_ids = next_ids.view(B * S)
@@ -386,6 +401,13 @@ class HeterogeneousNextTokenChooser:
         else:
             speculative_ids = None
 
+        # advance the grammar FSM state with the token chosen for each request
+        if self.grammar_processor is not None:
+            for i in range(len(self.fsm_grammar_states)):
+                self.fsm_grammar_states[i] = self.grammar_processor.advance(
+                    next_ids[i].item(), self.fsm_grammar_states[i], self.grammars[i]
+                )
+
         return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
 
     def filter(self, indices):
@@ -408,12 +430,20 @@ class HeterogeneousNextTokenChooser:
         self.seeds = [self.seeds[i] for i in indices]
         self.do_sample = [self.do_sample[i] for i in indices]
 
-        if self.use_grammar or any(self.do_sample):
-            _, new_fsm_grammar_states, new_grammars = self.choice.filter(indices, self.fsm_grammar_states, self.grammars)
-            self.fsm_grammar_states = new_fsm_grammar_states
-            self.grammars = new_grammars
-        else:
-            self.choice = Greedy()
+        new_grammars = []
+        new_fsm_grammar_states = []
+        for i in indices:
+            new_grammars.append(self.grammars[i])
+            new_fsm_grammar_states.append(self.fsm_grammar_states[i])
+
+        self.grammars = new_grammars
+        self.fsm_grammar_states = new_fsm_grammar_states
+
+        # fall back to Greedy only when no request in the batch samples
+        if any(self.do_sample):
+            self.choice.filter(indices)
+        else:
+            self.choice = Greedy()
 
         return self
 
@@ -438,7 +468,7 @@ class HeterogeneousNextTokenChooser:
             device=device,
             dtype=dtype,
             tokenizer=tokenizer,
-            grammar=[pb_.grammar for pb_ in pb],
+            grammars=[pb_.grammar for pb_ in pb],
             fsm_grammar_states=[pb_.fsm_grammar_state for pb_ in pb],
         )
 
@@ -449,7 +479,7 @@ class Sampling:
         self.generator.manual_seed(seed)
         self.seed = seed
 
-    def __call__(self, logits):
+    def __call__(self, logits, *args):
         probs = torch.nn.functional.softmax(logits, -1)
         # Avoid GPU<->CPU sync done by torch multinomial
         # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
@@ -464,97 +494,6 @@ class Greedy:
     def filter(self, indices, *args):
         return self
 
-
-class Grammar:
-    fsm_state: DefaultDict[int, int]
-    fsm: RegexFSM
-
-    def __init__(self, tokenizer, device):
-        self.device = device
-        self.tokenizer = tokenizer
-
-    def __call__(self, logits, fsm_grammar_states, grammars):
-        empty = torch.ones(logits.shape[0], dtype=torch.int64, device=logits.device)
-        try:
-            for i in range(len(fsm_grammar_states)):
-                if fsm_grammar_states[i] == -1:
-                    continue
-
-                # if grammar is '' or None, return the greedy token
-                if grammars[i] == "" or grammars[i] is None:
-                    empty[i] = logits[i].argmax().item()
-                    continue
-
-                # this is cached and should be fast after the first time
-                fsm = self.compile_fsm(grammars[i], self.tokenizer)
-                allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
-                mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
-                mask[allowed_tokens] = 0
-                biased_scores = logits[i : i + 1] + mask
-                greedy = biased_scores.argmax(dim=-1)
-
-                # if greedy is empty, return the eos token
-                if greedy.shape[0] == 0:
-                    continue
-
-                # import ipdb; ipdb.set_trace()
-                fsm_grammar_states[i] = fsm.next_state(
-                    fsm_grammar_states[i], greedy.item()
-                )
-
-                empty[i] = greedy.item()
-        except Exception as e:
-            print(f"Exception: {e}")
-            import ipdb
-
-            ipdb.set_trace()
-        return empty
-
-    @lru_cache(maxsize=32, typed=True)
-    def compile_fsm(self, schema, tokenizer):
-        start_time = time.time()
-        tokenizer = self.adapt_tokenizer(tokenizer)
-        is_json_string = schema.startswith("{") and schema.endswith("}")
-        regex_string = build_regex_from_object(schema) if is_json_string else schema
-        fsm = RegexFSM(regex_string, tokenizer)
-        print(f"Compile FSM: {time.time() - start_time}")
-        return fsm
-
-    def adapt_tokenizer(self, tokenizer):
-        """Adapt tokenizer to work with the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-
-        return tokenizer
-
-    def filter(self, indices, fsm_grammar_states, grammars):
-        new_fsm_grammar_states = []
-        new_grammars = []
-
-        for i in indices:
-            new_fsm_grammar_states.append(fsm_grammar_states[i])
-            new_grammars.append(grammars[i])
-
-        return self, new_fsm_grammar_states, new_grammars
-
 
 class HeterogeneousSampling:
     r"""
@@ -574,7 +513,7 @@ class HeterogeneousSampling:
 
         self.greedy = Greedy()
 
-    def __call__(self, logits, fsm_grammar_states, grammars):
+    def __call__(self, logits):
         out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
         if self.greedy_indices:
             # Computing for all indices is faster than slicing
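
Note on the mechanism this patch settles on: grammar handling becomes "mask, choose, advance" — the processor biases the scores so only FSM-legal tokens survive, an ordinary Sampling/Greedy choice picks a token, and `advance` moves the FSM to its next state. The sketch below is illustrative only, not TGI code: `ToyFSM` is a hypothetical stand-in for `outlines.fsm.fsm.RegexFSM` (it mimics just the two methods the patch uses, `allowed_token_ids` and `next_state`), the three-token vocabulary and random logits are invented, and the token choice is a plain argmax rather than the warper/sampling stack.

    import math

    import torch


    class ToyFSM:
        """Hypothetical stand-in for outlines' RegexFSM; accepts exactly "ab" then EOS."""

        def __init__(self):
            # state -> {allowed token id: next state}; -1 marks the final state
            self.transitions = {0: {0: 1}, 1: {1: 2}, 2: {2: -1}}

        def allowed_token_ids(self, state):
            return list(self.transitions[state].keys())

        def next_state(self, state, token_id):
            return self.transitions[state][token_id]


    vocab = ["a", "b", "<eos>"]
    fsm = ToyFSM()
    state = 0
    generated = []

    while state != -1:
        logits = torch.randn(len(vocab))  # stands in for the model's output scores
        # Same biasing as GrammarLogitProcessor.__call__: -inf everywhere,
        # 0 on the tokens the FSM allows from the current state.
        mask = torch.full((logits.shape[-1],), -math.inf)
        mask[fsm.allowed_token_ids(state)] = 0
        next_id = int((logits + mask).argmax())
        # Same bookkeeping as GrammarLogitProcessor.advance: the FSM state is
        # updated once, after the token has been chosen.
        state = fsm.next_state(state, next_id)
        generated.append(vocab[next_id])

    print("".join(generated))  # always prints "ab<eos>"

Because the grammar is now just another logits processor, it composes with watermarking, the repetition and frequency penalties, and the warpers, and the same FSM advance step works whether the token came from Greedy or Sampling.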