mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: support speculative decoding grammar advances
This commit is contained in:
parent
63c52fb22d
commit
c791187b02
@ -1054,7 +1054,6 @@ class FlashCausalLM(Model):
|
||||
next_token_ids = next_input_ids.tolist()
|
||||
accepted_ids = accepted_ids.tolist()
|
||||
start_decode = time.time_ns()
|
||||
batch.next_token_chooser = batch.next_token_chooser.advance_grammar(next_token_ids)
|
||||
|
||||
# Zipped iterator
|
||||
iterator = zip(
|
||||
@ -1211,6 +1210,12 @@ class FlashCausalLM(Model):
|
||||
|
||||
generations.append(generation)
|
||||
|
||||
# accept each new token for this specific request since we may
|
||||
# have more than one new token per request with speculative decoding
|
||||
for next_token_id in _next_token_ids:
|
||||
batch.next_token_chooser = batch.next_token_chooser.advance_grammar_single(i, next_token_id)
|
||||
|
||||
|
||||
# Update values
|
||||
batch.input_lengths[i] = input_length + n_accepted_ids
|
||||
if batch.input_lengths[i] > batch.max_seqlen:
|
||||
|
@ -593,5 +593,10 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
||||
for i in range(len(next_token_ids))
|
||||
]
|
||||
|
||||
def advance_at_index(self, next_token_id, fsm_grammar_state, index):
    """Advance the grammar FSM for one batch slot.

    Looks up the FSM belonging to batch position *index* and delegates
    the state transition to ``GrammarLogitProcessor._advance``, returning
    the new FSM state after consuming *next_token_id*.
    """
    # Each batch position owns its own FSM; pick the right one first.
    fsm_for_slot = self.fsms[index]
    new_state = GrammarLogitProcessor._advance(
        next_token_id, fsm_grammar_state, fsm_for_slot
    )
    return new_state
|
||||
|
||||
def filter(self, indices):
    """Keep only the grammar state for the surviving batch *indices*.

    Reuses the base-class filtering logic on this instance rather than
    re-implementing it here.
    """
    filtered = GrammarLogitProcessor.filter(self, indices)
    return filtered
|
||||
|
@ -307,6 +307,7 @@ class HeterogeneousNextTokenChooser:
|
||||
self.tokenizer = tokenizer
|
||||
self.fsm_grammar_states = fsm_grammar_states
|
||||
self.grammars = grammars
|
||||
self.grammar_types = grammar_types
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -406,6 +407,13 @@ class HeterogeneousNextTokenChooser:
|
||||
self.fsm_grammar_states = other_new_states
|
||||
return self
|
||||
|
||||
def advance_grammar_single(self, grammar_state_index: int, next_id: int):
    """Advance the grammar FSM state for a single request.

    Used on the speculative-decoding path, where each accepted token for a
    request must advance that request's grammar state individually.

    Args:
        grammar_state_index: position of the request inside the batch.
        next_id: the token id that was just accepted for that request.

    Returns:
        ``self``, so calls can be chained / reassigned fluently.
    """
    processor = self.grammar_processor
    # No grammar constraints configured -> nothing to advance.
    if processor is not None:
        current_state = self.fsm_grammar_states[grammar_state_index]
        self.fsm_grammar_states[grammar_state_index] = processor.advance_at_index(
            next_id, current_state, grammar_state_index
        )
    return self
|
||||
|
||||
def filter(self, indices):
|
||||
if self.watermark_processor is not None:
|
||||
self.watermark_processor = self.watermark_processor.filter(indices)
|
||||
|
Loading…
Reference in New Issue
Block a user