mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
fix: include fsm_grammar_states in FlashMistralBatch from_pb
This commit is contained in:
parent
ff42d33e99
commit
2762e6883e
@@ -98,6 +98,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
|||||||
prefill_cu_outlens = [0]
|
prefill_cu_outlens = [0]
|
||||||
|
|
||||||
next_token_chooser_parameters = []
|
next_token_chooser_parameters = []
|
||||||
|
fsm_grammar_states = []
|
||||||
stopping_criterias = []
|
stopping_criterias = []
|
||||||
top_n_tokens = []
|
top_n_tokens = []
|
||||||
|
|
||||||
@@ -136,6 +137,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
|||||||
cu_seqlen_prefill.append(cumulative_length + input_length)
|
cu_seqlen_prefill.append(cumulative_length + input_length)
|
||||||
|
|
||||||
next_token_chooser_parameters.append(r.parameters)
|
next_token_chooser_parameters.append(r.parameters)
|
||||||
|
fsm_grammar_states.append(r.fsm_grammar_state)
|
||||||
|
|
||||||
stopping_criteria = StoppingCriteria.from_pb(
|
stopping_criteria = StoppingCriteria.from_pb(
|
||||||
r.stopping_parameters, tokenizer
|
r.stopping_parameters, tokenizer
|
||||||
@@ -204,7 +206,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
|||||||
)
|
)
|
||||||
|
|
||||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||||
next_token_chooser_parameters, dtype, device, tokenizer
|
next_token_chooser_parameters, dtype, device, tokenizer, fsm_grammar_states
|
||||||
)
|
)
|
||||||
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user