mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: fix advance_grammar sig, add comment and move advance call
This commit is contained in:
parent
d39e45abc3
commit
3df37fa941
@ -1034,9 +1034,6 @@ class FlashCausalLM(Model):
|
||||
cumulative_length += input_length
|
||||
|
||||
# Update values
|
||||
batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
|
||||
next_input_ids
|
||||
)
|
||||
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
|
||||
batch.speculative_ids = speculative_ids
|
||||
batch.position_ids = next_position_ids + accepted_ids
|
||||
@ -1053,6 +1050,9 @@ class FlashCausalLM(Model):
|
||||
prefill_logprobs = prefill_logprobs.view(-1).tolist()
|
||||
|
||||
# GPU <-> CPU sync
|
||||
batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
|
||||
next_input_ids.tolist(),
|
||||
)
|
||||
next_token_logprobs = next_token_logprobs.tolist()
|
||||
next_token_ids = next_input_ids.tolist()
|
||||
accepted_ids = accepted_ids.tolist()
|
||||
|
@ -505,10 +505,15 @@ class GrammarLogitProcessor(LogitsProcessor):
|
||||
return fsm_grammar_state
|
||||
return fsm.next_state(fsm_grammar_state, next_token_id)
|
||||
|
||||
# TODO: move grammar compilation into the router
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=32, typed=True)
|
||||
def _cached_compile_fsm(schema, tokenizer):
|
||||
start_time = time.time()
|
||||
# Detect if schema is a json object before converting it to regex.
|
||||
# We need to check if it's a valid json object before converting it to regex
|
||||
# and cannot simply test if it starts with '{' and ends with '}' because there
|
||||
# are valid regexes that start and end with curly braces.
|
||||
try:
|
||||
json.loads(schema) # check if schema is a valid json
|
||||
schema = build_regex_from_object(schema) # convert schema to regex
|
||||
|
@ -94,7 +94,7 @@ class NextTokenChooser:
|
||||
|
||||
return next_id, next_logprob
|
||||
|
||||
def advance_grammar(self, next_id):
|
||||
def advance_grammar(self, next_id: int):
|
||||
if self.grammar_processor is not None:
|
||||
self.fsm_grammar_state = self.grammar_processor.advance(
|
||||
next_id, self.fsm_grammar_state
|
||||
@ -388,10 +388,10 @@ class HeterogeneousNextTokenChooser:
|
||||
|
||||
return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
|
||||
|
||||
def advance_grammar(self, next_ids: torch.Tensor):
|
||||
def advance_grammar(self, next_ids: List[int]):
|
||||
if self.grammar_processor is not None:
|
||||
other_new_states = self.grammar_processor.advance_batch(
|
||||
next_ids.tolist(), self.fsm_grammar_states, self.grammars
|
||||
next_ids, self.fsm_grammar_states, self.grammars
|
||||
)
|
||||
self.fsm_grammar_states = other_new_states
|
||||
return self
|
||||
|
Loading…
Reference in New Issue
Block a user