From c791187b0254c8f9db453f99ccaf9baaf01767a3 Mon Sep 17 00:00:00 2001
From: drbh
Date: Wed, 14 Feb 2024 17:18:47 +0000
Subject: [PATCH] feat: support speculative decoding grammar advances

---
 server/text_generation_server/models/flash_causal_lm.py | 7 ++++++-
 server/text_generation_server/utils/logits_process.py   | 5 +++++
 server/text_generation_server/utils/tokens.py           | 8 ++++++++
 3 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 25eacf64..7ec8c2fc 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1054,7 +1054,6 @@ class FlashCausalLM(Model):
         next_token_ids = next_input_ids.tolist()
         accepted_ids = accepted_ids.tolist()
         start_decode = time.time_ns()
-        batch.next_token_chooser = batch.next_token_chooser.advance_grammar(next_token_ids)
 
         # Zipped iterator
         iterator = zip(
@@ -1211,6 +1210,12 @@ class FlashCausalLM(Model):
 
             generations.append(generation)
 
+            # accept each new token for this specific request since we may
+            # have more than one new token per request with speculative decoding
+            for next_token_id in _next_token_ids:
+                batch.next_token_chooser = batch.next_token_chooser.advance_grammar_single(i, next_token_id)
+
+
             # Update values
             batch.input_lengths[i] = input_length + n_accepted_ids
             if batch.input_lengths[i] > batch.max_seqlen:
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index ddc151da..73fcf53f 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -593,5 +593,10 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
             for i in range(len(next_token_ids))
         ]
 
+    def advance_at_index(self, next_token_id, fsm_grammar_state, index):
+        return GrammarLogitProcessor._advance(
+            next_token_id, fsm_grammar_state, self.fsms[index]
+        )
+
     def filter(self, indices):
         return GrammarLogitProcessor.filter(self, indices)
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 896d4d40..2784585e 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -307,6 +307,7 @@ class HeterogeneousNextTokenChooser:
         self.tokenizer = tokenizer
         self.fsm_grammar_states = fsm_grammar_states
         self.grammars = grammars
+        self.grammar_types = grammar_types
 
     def __call__(
         self,
@@ -406,6 +407,13 @@ class HeterogeneousNextTokenChooser:
         self.fsm_grammar_states = other_new_states
         return self
 
+    def advance_grammar_single(self, grammar_state_index: int, next_id: int):
+        if self.grammar_processor is not None:
+            self.fsm_grammar_states[grammar_state_index] = self.grammar_processor.advance_at_index(
+                next_id, self.fsm_grammar_states[grammar_state_index], grammar_state_index
+            )
+        return self
+
     def filter(self, indices):
         if self.watermark_processor is not None:
             self.watermark_processor = self.watermark_processor.filter(indices)