fix: include fsm_grammar_states in FlashMistralBatch from_pb

drbh 2024-04-08 17:23:46 +00:00
parent ff42d33e99
commit 2762e6883e


@@ -98,6 +98,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         prefill_cu_outlens = [0]
         next_token_chooser_parameters = []
+        fsm_grammar_states = []
         stopping_criterias = []
         top_n_tokens = []
@@ -136,6 +137,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
             cu_seqlen_prefill.append(cumulative_length + input_length)
             next_token_chooser_parameters.append(r.parameters)
+            fsm_grammar_states.append(r.fsm_grammar_state)
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -204,7 +206,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         )
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device, tokenizer
+            next_token_chooser_parameters, dtype, device, tokenizer, fsm_grammar_states
         )
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
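
For context, the pattern this fix restores is sketched below: when a batch is rebuilt from its protobuf message, each request's FSM grammar state is collected alongside its sampling parameters and forwarded to the next-token chooser, so grammar-constrained requests resume from the correct state instead of restarting at state 0. This is a minimal, self-contained illustration; the Request and NextTokenChooser classes are simplified stand-ins, not the actual text-generation-inference types or signatures.

# Minimal sketch of the from_pb pattern (stand-in types, not the real TGI classes).
from dataclasses import dataclass
from typing import List


@dataclass
class Request:
    parameters: dict          # sampling parameters carried in the protobuf request
    fsm_grammar_state: int    # current position in the grammar FSM for this request


class NextTokenChooser:
    def __init__(self, parameters: List[dict], fsm_grammar_states: List[int]):
        self.parameters = parameters
        # Without the per-request states, grammar-constrained requests would
        # restart from state 0 after the batch round-trips through protobuf.
        self.fsm_grammar_states = fsm_grammar_states

    @classmethod
    def from_pb(cls, parameters: List[dict], fsm_grammar_states: List[int]) -> "NextTokenChooser":
        return cls(parameters, fsm_grammar_states)


def batch_from_pb(requests: List[Request]) -> NextTokenChooser:
    next_token_chooser_parameters = []
    fsm_grammar_states = []
    for r in requests:
        next_token_chooser_parameters.append(r.parameters)
        fsm_grammar_states.append(r.fsm_grammar_state)  # the collection this commit adds
    # Pass both lists to the chooser factory, mirroring the extra argument in the diff above.
    return NextTokenChooser.from_pb(next_token_chooser_parameters, fsm_grammar_states)


if __name__ == "__main__":
    chooser = batch_from_pb(
        [Request({"temperature": 0.7}, 3), Request({"temperature": 1.0}, 0)]
    )
    print(chooser.fsm_grammar_states)  # [3, 0]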