From 3df37fa9412c4716d2fe39322b2b805387268e3c Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 13 Feb 2024 17:26:01 +0000 Subject: [PATCH] feat: fix advance_grammar sig, add comment and move advance call --- server/text_generation_server/models/flash_causal_lm.py | 6 +++--- server/text_generation_server/utils/logits_process.py | 5 +++++ server/text_generation_server/utils/tokens.py | 6 +++--- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a286e41c..0b455fec 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1034,9 +1034,6 @@ class FlashCausalLM(Model): cumulative_length += input_length # Update values - batch.next_token_chooser = batch.next_token_chooser.advance_grammar( - next_input_ids - ) batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] batch.speculative_ids = speculative_ids batch.position_ids = next_position_ids + accepted_ids @@ -1053,6 +1050,9 @@ class FlashCausalLM(Model): prefill_logprobs = prefill_logprobs.view(-1).tolist() # GPU <-> CPU sync + batch.next_token_chooser = batch.next_token_chooser.advance_grammar( + next_input_ids.tolist(), + ) next_token_logprobs = next_token_logprobs.tolist() next_token_ids = next_input_ids.tolist() accepted_ids = accepted_ids.tolist() diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index cd43e2a4..e88e8a74 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -505,10 +505,15 @@ class GrammarLogitProcessor(LogitsProcessor): return fsm_grammar_state return fsm.next_state(fsm_grammar_state, next_token_id) + # TODO: move grammar compilation into the router @staticmethod @lru_cache(maxsize=32, typed=True) def _cached_compile_fsm(schema, 
tokenizer): start_time = time.time() + # Detect whether the schema is a valid JSON object before converting it to regex. + # We must actually parse it with json.loads to decide, and cannot simply + # test if it starts with '{' and ends with '}' because there + # are valid regexes that start and end with curly braces. try: json.loads(schema) # check if schema is a valid json schema = build_regex_from_object(schema) # convert schema to regex diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 70bde1d1..8d268c65 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -94,7 +94,7 @@ class NextTokenChooser: return next_id, next_logprob - def advance_grammar(self, next_id): + def advance_grammar(self, next_id: int): if self.grammar_processor is not None: self.fsm_grammar_state = self.grammar_processor.advance( next_id, self.fsm_grammar_state @@ -388,10 +388,10 @@ class HeterogeneousNextTokenChooser: return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids - def advance_grammar(self, next_ids: torch.Tensor): + def advance_grammar(self, next_ids: List[int]): if self.grammar_processor is not None: other_new_states = self.grammar_processor.advance_batch( - next_ids.tolist(), self.fsm_grammar_states, self.grammars + next_ids, self.fsm_grammar_states, self.grammars ) self.fsm_grammar_states = other_new_states return self