From 8fd2664a3cd9e7b061b10b687b5a76d0bffe0b88 Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 8 Feb 2024 19:56:16 +0000
Subject: [PATCH] feat: support other models and add fsm caching

---
 .../models/flash_causal_lm.py                 |  4 +-
 .../models/flash_mistral.py                   |  2 +-
 server/text_generation_server/models/mamba.py |  2 +-
 server/text_generation_server/utils/tokens.py | 37 ++++++++++++-------
 4 files changed, 29 insertions(+), 16 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 886fe486..12be2485 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -237,7 +237,7 @@ class FlashCausalLMBatch(Batch):
         )
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device
+            next_token_chooser_parameters, dtype, device, tokenizer
         )
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -593,6 +593,8 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters,
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
+            tokenizer=batches[0].next_token_chooser.tokenizer,
+            grammar=batches[0].requests.parameters.grammar,
         )
 
         speculative_ids = (
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 34a50194..70669c8d 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -192,7 +192,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         )
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device
+            next_token_chooser_parameters, dtype, device, tokenizer
         )
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 868db6aa..774b45c0 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -124,7 +124,7 @@ class MambaBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 802b5f88..b4abeb9b 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -22,6 +22,7 @@ from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 
 from outlines.fsm.fsm import RegexFSM
 from outlines.fsm.json_schema import build_regex_from_object
+from functools import lru_cache
 
 # TODO: remove when done debugging
 import time
@@ -219,6 +220,8 @@ class HeterogeneousNextTokenChooser:
         typical_p: List[float],
         do_sample: List[bool],
         seeds: List[int],
+        tokenizer=None,
+        grammar=None,
     ):
         warpers = []
 
@@ -272,11 +275,15 @@ class HeterogeneousNextTokenChooser:
         self.warpers = warpers
 
-        if any(do_sample):
+        first_grammar = grammar[0] if grammar else None
+        if first_grammar:
+            self.choice = Grammar(tokenizer, device, first_grammar)
+        elif any(do_sample):
             self.choice = HeterogeneousSampling(do_sample, seeds, device)
         else:
             self.choice = Greedy()
 
+        self.use_grammar = grammar is not None
         self.seeds = seeds
         self.do_sample = do_sample
         self.dtype = dtype
         self.device = device
@@ -390,7 +397,7 @@ class HeterogeneousNextTokenChooser:
         self.seeds = [self.seeds[i] for i in indices]
         self.do_sample = [self.do_sample[i] for i in indices]
 
-        if any(self.do_sample):
+        if self.use_grammar or any(self.do_sample):
             self.choice.filter(indices)
         else:
             self.choice = Greedy()
@@ -403,6 +410,7 @@ class HeterogeneousNextTokenChooser:
         pb: List[generate_pb2.NextTokenChooserParameters],
         dtype: torch.dtype,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -416,6 +424,8 @@ class HeterogeneousNextTokenChooser:
             seeds=[pb_.seed for pb_ in pb],
             device=device,
             dtype=dtype,
+            tokenizer=tokenizer,
+            grammar=[pb_.grammar for pb_ in pb],
         )
 
@@ -443,22 +453,12 @@ class Grammar:
     fsm: RegexFSM
 
     def __init__(self, tokenizer, device, grammar):
-        # TODO: remove debug logs
-        start_time = time.time()
-        tokenizer = self.adapt_tokenizer(tokenizer)
-        print(f"Adapt tokenizer: {time.time() - start_time}")
-        start_time = time.time()
-        regex_string = build_regex_from_object(grammar)
-        print(f"Build regex: {time.time() - start_time}")
-        fsm = RegexFSM(regex_string, tokenizer)
-        print(f"Compile FSM: {time.time() - start_time}")
-
+        fsm = self.compile_fsm(grammar, tokenizer)
         self.fsm = fsm
         self.fsm_state = defaultdict(int)
         self.device = device
 
     def __call__(self, logits):
-        # TODO: handle seq_id properly
         seq_id = 0
 
         if self.fsm_state[seq_id] == -1:
@@ -477,6 +477,17 @@ class Grammar:
             self.fsm_state[seq_id], greedy.item()
         )
         return greedy
+
+    @lru_cache(maxsize=32, typed=True)
+    def compile_fsm(self, schema, tokenizer):
+        start_time = time.time()
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        is_json_string = schema.startswith("{") and schema.endswith("}")
+        regex_string = build_regex_from_object(schema) if is_json_string else schema
+        fsm = RegexFSM(regex_string, tokenizer)
+        print(f"Compile FSM: {time.time() - start_time}")
+        return fsm
+
     def adapt_tokenizer(self, tokenizer):
         """Adapt tokenizer to work with the FSM.