mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 07:42:06 +00:00
fix: include fsm_grammar_states in FlashMistralBatch from_pb
This commit is contained in:
parent
ff42d33e99
commit
2762e6883e
@@ -98,6 +98,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||
prefill_cu_outlens = [0]
|
||||
|
||||
next_token_chooser_parameters = []
|
||||
+ fsm_grammar_states = []
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
|
||||
@@ -136,6 +137,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||
cu_seqlen_prefill.append(cumulative_length + input_length)
|
||||
|
||||
next_token_chooser_parameters.append(r.parameters)
|
||||
+ fsm_grammar_states.append(r.fsm_grammar_state)
|
||||
|
||||
stopping_criteria = StoppingCriteria.from_pb(
|
||||
r.stopping_parameters, tokenizer
|
||||
@@ -204,7 +206,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||
)
|
||||
|
||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||
- next_token_chooser_parameters, dtype, device, tokenizer
|
||||
+ next_token_chooser_parameters, dtype, device, tokenizer, fsm_grammar_states
|
||||
)
|
||||
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user