feat: support speculative decoding grammar advances

This commit is contained in:
drbh 2024-02-14 17:18:47 +00:00
parent 63c52fb22d
commit c791187b02
3 changed files with 19 additions and 1 deletions

View File

@ -1054,7 +1054,6 @@ class FlashCausalLM(Model):
next_token_ids = next_input_ids.tolist() next_token_ids = next_input_ids.tolist()
accepted_ids = accepted_ids.tolist() accepted_ids = accepted_ids.tolist()
start_decode = time.time_ns() start_decode = time.time_ns()
batch.next_token_chooser = batch.next_token_chooser.advance_grammar(next_token_ids)
# Zipped iterator # Zipped iterator
iterator = zip( iterator = zip(
@ -1211,6 +1210,12 @@ class FlashCausalLM(Model):
generations.append(generation) generations.append(generation)
# accept each new token for this specific request since we may
# have more than one new token per request with speculative decoding
for next_token_id in _next_token_ids:
batch.next_token_chooser = batch.next_token_chooser.advance_grammar_single(i, next_token_id)
# Update values # Update values
batch.input_lengths[i] = input_length + n_accepted_ids batch.input_lengths[i] = input_length + n_accepted_ids
if batch.input_lengths[i] > batch.max_seqlen: if batch.input_lengths[i] > batch.max_seqlen:

View File

@ -593,5 +593,10 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
for i in range(len(next_token_ids)) for i in range(len(next_token_ids))
] ]
def advance_at_index(self, next_token_id, fsm_grammar_state, index):
    """Advance the grammar state of a single batch slot by one token.

    Looks up the FSM stored at ``self.fsms[index]`` and delegates to
    ``GrammarLogitProcessor._advance`` with the given token id and state.

    Returns whatever ``GrammarLogitProcessor._advance`` returns
    (presumably the next FSM state -- confirm against ``_advance``).
    """
    slot_fsm = self.fsms[index]
    return GrammarLogitProcessor._advance(next_token_id, fsm_grammar_state, slot_fsm)
def filter(self, indices): def filter(self, indices):
return GrammarLogitProcessor.filter(self, indices) return GrammarLogitProcessor.filter(self, indices)

View File

@ -307,6 +307,7 @@ class HeterogeneousNextTokenChooser:
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.fsm_grammar_states = fsm_grammar_states self.fsm_grammar_states = fsm_grammar_states
self.grammars = grammars self.grammars = grammars
self.grammar_types = grammar_types
def __call__( def __call__(
self, self,
@ -406,6 +407,13 @@ class HeterogeneousNextTokenChooser:
self.fsm_grammar_states = other_new_states self.fsm_grammar_states = other_new_states
return self return self
def advance_grammar_single(self, grammar_state_index: int, next_id: int):
    """Advance one entry of ``self.fsm_grammar_states`` by a single token.

    When ``self.grammar_processor`` is ``None`` nothing is changed.
    Otherwise the state stored at ``grammar_state_index`` is replaced, in
    place, by the result of ``grammar_processor.advance_at_index`` called
    with the token id, the current state and the same index.

    Returns ``self`` in both cases, so callers can rebind the chooser.
    """
    processor = self.grammar_processor
    if processor is None:
        return self

    current_state = self.fsm_grammar_states[grammar_state_index]
    self.fsm_grammar_states[grammar_state_index] = processor.advance_at_index(
        next_id, current_state, grammar_state_index
    )
    return self
def filter(self, indices): def filter(self, indices):
if self.watermark_processor is not None: if self.watermark_processor is not None:
self.watermark_processor = self.watermark_processor.filter(indices) self.watermark_processor = self.watermark_processor.filter(indices)