Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 12:24:53 +00:00
feat: support speculative decoding grammar advances
This commit is contained in:
parent 63c52fb22d
commit c791187b02
@@ -1054,7 +1054,6 @@ class FlashCausalLM(Model):
         next_token_ids = next_input_ids.tolist()
         accepted_ids = accepted_ids.tolist()
         start_decode = time.time_ns()
-        batch.next_token_chooser = batch.next_token_chooser.advance_grammar(next_token_ids)
 
         # Zipped iterator
         iterator = zip(
@@ -1211,6 +1210,12 @@ class FlashCausalLM(Model):
             generations.append(generation)
 
+            # accept each new token for this specific request since we may
+            # have more than one new token per request with speculative decoding
+            for next_token_id in _next_token_ids:
+                batch.next_token_chooser = batch.next_token_chooser.advance_grammar_single(i, next_token_id)
+
             # Update values
             batch.input_lengths[i] = input_length + n_accepted_ids
             if batch.input_lengths[i] > batch.max_seqlen:
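For context, a minimal runnable sketch, separate from the diff, of why speculative decoding forces the grammar FSM to advance one accepted token at a time: a request may accept several draft tokens in a single decode step, and each of them is a separate FSM transition. The transition table, token ids, and `advance_fsm` helper below are toy stand-ins, not TGI code.

```python
# Toy FSM advance: one transition per accepted token. All names here are
# illustrative stand-ins, not part of text-generation-inference.

def advance_fsm(state: int, token_id: int, transitions: dict) -> int:
    """Advance a toy grammar FSM by a single token; -1 marks a dead state."""
    return transitions.get((state, token_id), -1)

# Grammar accepting the token sequence [0, 1] (think "{" then "}").
transitions = {(0, 0): 1, (1, 1): 2}

# With speculative decoding, one decode step may accept several tokens.
accepted_tokens = [0, 1]

state = 0
for tok in accepted_tokens:  # one FSM step per accepted token, in order
    state = advance_fsm(state, tok, transitions)

assert state == 2  # the FSM reached its accepting state
```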
@@ -593,5 +593,10 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
             for i in range(len(next_token_ids))
         ]
 
+    def advance_at_index(self, next_token_id, fsm_grammar_state, index):
+        return GrammarLogitProcessor._advance(
+            next_token_id, fsm_grammar_state, self.fsms[index]
+        )
+
     def filter(self, indices):
         return GrammarLogitProcessor.filter(self, indices)
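The new `advance_at_index` routes a single-token advance to the FSM belonging to one request in a heterogeneous batch. A compact sketch of that shape, with toy types standing in for the real processor and FSM classes:

```python
# Illustrative shape of per-index advancement; Fsm and the class below are
# toy stand-ins, not the real GrammarLogitProcessor hierarchy.
from typing import Dict, List, Tuple

Fsm = Dict[Tuple[int, int], int]  # (state, token_id) -> next_state

def _advance(next_token_id: int, state: int, fsm: Fsm) -> int:
    # Shared single-request helper, analogous to GrammarLogitProcessor._advance.
    return fsm.get((state, next_token_id), -1)

class ToyHeterogeneousProcessor:
    def __init__(self, fsms: List[Fsm]):
        self.fsms = fsms  # one FSM per request in the batch

    def advance_at_index(self, next_token_id: int, state: int, index: int) -> int:
        # Use request `index`'s own FSM instead of a single shared one.
        return _advance(next_token_id, state, self.fsms[index])
```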
@@ -307,6 +307,7 @@ class HeterogeneousNextTokenChooser:
         self.tokenizer = tokenizer
         self.fsm_grammar_states = fsm_grammar_states
         self.grammars = grammars
+        self.grammar_types = grammar_types
 
     def __call__(
         self,
@@ -406,6 +407,13 @@ class HeterogeneousNextTokenChooser:
         self.fsm_grammar_states = other_new_states
         return self
 
+    def advance_grammar_single(self, grammar_state_index: int, next_id: int):
+        if self.grammar_processor is not None:
+            self.fsm_grammar_states[grammar_state_index] = self.grammar_processor.advance_at_index(
+                next_id, self.fsm_grammar_states[grammar_state_index], grammar_state_index
+            )
+        return self
+
     def filter(self, indices):
         if self.watermark_processor is not None:
             self.watermark_processor = self.watermark_processor.filter(indices)
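Putting the two new methods together, the decode loop can now advance each request independently, even when requests accept different numbers of speculative tokens. A runnable toy version of that flow (again with stand-in classes, not TGI's):

```python
from typing import Dict, List, Tuple

Fsm = Dict[Tuple[int, int], int]  # (state, token_id) -> next_state

class ToyChooser:
    """Stand-in for HeterogeneousNextTokenChooser's grammar bookkeeping."""

    def __init__(self, fsms: List[Fsm]):
        self.fsms = fsms
        self.fsm_grammar_states = [0] * len(fsms)

    def advance_grammar_single(self, i: int, next_id: int) -> "ToyChooser":
        # Advance only request i's grammar state by one token.
        self.fsm_grammar_states[i] = self.fsms[i].get(
            (self.fsm_grammar_states[i], next_id), -1
        )
        return self

fsm: Fsm = {(0, 7): 1, (1, 7): 2}
chooser = ToyChooser([fsm, fsm])

# Speculative decoding: requests may accept different numbers of tokens.
accepted: List[List[int]] = [[7, 7], [7]]

for i, toks in enumerate(accepted):
    for next_token_id in toks:
        chooser = chooser.advance_grammar_single(i, next_token_id)

print(chooser.fsm_grammar_states)  # [2, 1]
```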