feat: fix advance_grammar sig, add comment and move advance call

This commit is contained in:
drbh 2024-02-13 17:26:01 +00:00
parent d39e45abc3
commit 3df37fa941
3 changed files with 11 additions and 6 deletions

View File

@@ -1034,9 +1034,6 @@ class FlashCausalLM(Model):
cumulative_length += input_length
# Update values
batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
next_input_ids
)
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
batch.speculative_ids = speculative_ids
batch.position_ids = next_position_ids + accepted_ids
@@ -1053,6 +1050,9 @@ class FlashCausalLM(Model):
prefill_logprobs = prefill_logprobs.view(-1).tolist()
# GPU <-> CPU sync
batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
next_input_ids.tolist(),
)
next_token_logprobs = next_token_logprobs.tolist()
next_token_ids = next_input_ids.tolist()
accepted_ids = accepted_ids.tolist()

View File

@@ -505,10 +505,15 @@ class GrammarLogitProcessor(LogitsProcessor):
return fsm_grammar_state
return fsm.next_state(fsm_grammar_state, next_token_id)
# TODO: move grammar compilation into the router
@staticmethod
@lru_cache(maxsize=32, typed=True)
def _cached_compile_fsm(schema, tokenizer):
start_time = time.time()
# Detect if schema is a json object before converting it to regex.
# We need to check if it's a valid json object before converting it to regex
# and cannot simply test if it starts with '{' and ends with '}' because there
# are valid regexes that start and end with curly braces.
try:
json.loads(schema)  # check if schema is a valid json
schema = build_regex_from_object(schema)  # convert schema to regex

View File

@@ -94,7 +94,7 @@ class NextTokenChooser:
return next_id, next_logprob
def advance_grammar(self, next_id): def advance_grammar(self, next_id: int):
if self.grammar_processor is not None:
self.fsm_grammar_state = self.grammar_processor.advance(
next_id, self.fsm_grammar_state
@@ -388,10 +388,10 @@ class HeterogeneousNextTokenChooser:
return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
def advance_grammar(self, next_ids: torch.Tensor): def advance_grammar(self, next_ids: List[int]):
if self.grammar_processor is not None:
other_new_states = self.grammar_processor.advance_batch(
next_ids.tolist(), self.fsm_grammar_states, self.grammars next_ids, self.fsm_grammar_states, self.grammars
)
self.fsm_grammar_states = other_new_states
return self