Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00
feat: fix advance_grammar sig, add comment and move advance call
This commit is contained in:
parent d39e45abc3
commit 3df37fa941
@@ -1034,9 +1034,6 @@ class FlashCausalLM(Model):
             cumulative_length += input_length

         # Update values
-        batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
-            next_input_ids
-        )
         batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
         batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + accepted_ids
@@ -1053,6 +1050,9 @@ class FlashCausalLM(Model):
         prefill_logprobs = prefill_logprobs.view(-1).tolist()

         # GPU <-> CPU sync
+        batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
+            next_input_ids.tolist(),
+        )
         next_token_logprobs = next_token_logprobs.tolist()
         next_token_ids = next_input_ids.tolist()
         accepted_ids = accepted_ids.tolist()
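Note (not part of the commit): the two hunks above move the grammar advance out of the tensor-update block and next to the existing GPU <-> CPU sync, so the chooser now receives plain Python ints from next_input_ids.tolist() instead of a device tensor. A minimal sketch of that ordering; ToyChooser and its state update are invented for illustration, and only the call ordering mirrors the diff:

from typing import List

import torch


class ToyChooser:
    """Stand-in for the batch's next_token_chooser; not the TGI class."""

    def __init__(self, num_requests: int):
        self.fsm_grammar_states = [0] * num_requests  # one FSM state per request

    def advance_grammar(self, next_ids: List[int]):
        # next_ids are already plain Python ints: no extra device sync happens here.
        self.fsm_grammar_states = [
            state + next_id for state, next_id in zip(self.fsm_grammar_states, next_ids)
        ]
        return self


next_input_ids = torch.tensor([5, 7, 9])  # stays on the device until the sync point
chooser = ToyChooser(num_requests=3)

# GPU <-> CPU sync: convert once, then reuse the Python list below
next_token_ids = next_input_ids.tolist()
chooser = chooser.advance_grammar(next_token_ids)
print(chooser.fsm_grammar_states)  # [5, 7, 9]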
@@ -505,10 +505,15 @@ class GrammarLogitProcessor(LogitsProcessor):
             return fsm_grammar_state
         return fsm.next_state(fsm_grammar_state, next_token_id)

+    # TODO: move grammar compilation into the router
     @staticmethod
     @lru_cache(maxsize=32, typed=True)
     def _cached_compile_fsm(schema, tokenizer):
         start_time = time.time()
+        # Detect if schema is a json object before converting it to regex.
+        # We need to check if it's a valid json object before converting it to regex
+        # and cannot simply test if it starts with '{' and ends with '}' because there
+        # are valid regexes that start and end with curly braces.
         try:
             json.loads(schema)  # check if schema is a valid json
             schema = build_regex_from_object(schema)  # convert schema to regex
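Note (not part of the commit): the new comment explains why the schema is parsed rather than tested for surrounding braces: a pattern can be brace-delimited and still be a perfectly valid regex without being JSON. A standalone sketch of that check, where is_json_schema is a made-up helper and not part of the library:

import json


def is_json_schema(schema: str) -> bool:
    # A prefix/suffix test on '{' and '}' is not enough: the regex below is
    # brace-delimited but not valid JSON, so the string is actually parsed.
    try:
        json.loads(schema)
        return True
    except json.JSONDecodeError:
        return False


print(is_json_schema('{"type": "object"}'))  # True  -> convert with build_regex_from_object
print(is_json_schema('{"a": [0-9]+}'))       # False -> treat as an already-written regex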
@@ -94,7 +94,7 @@ class NextTokenChooser:

         return next_id, next_logprob

-    def advance_grammar(self, next_id):
+    def advance_grammar(self, next_id: int):
         if self.grammar_processor is not None:
             self.fsm_grammar_state = self.grammar_processor.advance(
                 next_id, self.fsm_grammar_state
@@ -388,10 +388,10 @@ class HeterogeneousNextTokenChooser:

         return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids

-    def advance_grammar(self, next_ids: torch.Tensor):
+    def advance_grammar(self, next_ids: List[int]):
         if self.grammar_processor is not None:
             other_new_states = self.grammar_processor.advance_batch(
-                next_ids.tolist(), self.fsm_grammar_states, self.grammars
+                next_ids, self.fsm_grammar_states, self.grammars
             )
             self.fsm_grammar_states = other_new_states
         return self
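Note (not part of the commit): with this hunk the .tolist() conversion moves from inside advance_grammar to the call site, so the method's contract becomes List[int] instead of torch.Tensor (and the single-sequence NextTokenChooser.advance_grammar above is annotated with int). A small sketch of the before/after calling convention; both functions are placeholders, not the TGI classes:

from typing import List

import torch


def advance_batch_old(next_ids: torch.Tensor, fsm_states: List[int]) -> List[int]:
    # Old contract: the method accepted a tensor and converted it itself.
    return [state + token for state, token in zip(fsm_states, next_ids.tolist())]


def advance_batch_new(next_ids: List[int], fsm_states: List[int]) -> List[int]:
    # New contract: the caller converts once (e.g. at the GPU <-> CPU sync point)
    # and the method only ever sees plain Python ints.
    return [state + token for state, token in zip(fsm_states, next_ids)]


next_input_ids = torch.tensor([11, 3, 8])
assert advance_batch_old(next_input_ids, [0, 0, 0]) == advance_batch_new(
    next_input_ids.tolist(), [0, 0, 0]
)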