diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 2e1055b2..b5f0ca65 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -98,6 +98,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         prefill_cu_outlens = [0]
 
         next_token_chooser_parameters = []
+        fsm_grammar_states = []
         stopping_criterias = []
         top_n_tokens = []
 
@@ -136,6 +137,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
             cu_seqlen_prefill.append(cumulative_length + input_length)
 
             next_token_chooser_parameters.append(r.parameters)
+            fsm_grammar_states.append(r.fsm_grammar_state)
 
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -204,7 +206,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         )
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device, tokenizer
+            next_token_chooser_parameters, dtype, device, tokenizer, fsm_grammar_states
         )
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
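
For context, a minimal sketch of the pattern this diff implements: each request carries its own grammar FSM state, batch construction collects those states into a list built in parallel with the sampling parameters, and both lists are forwarded to the next-token chooser in a single `from_pb` call. The `Request` and `SimpleChooser` classes below are illustrative stand-ins, not TGI's actual protobuf or `HeterogeneousNextTokenChooser` types (which also take `dtype`, `device`, and `tokenizer`).

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Request:
    # Stand-in for the protobuf request: sampling parameters plus the
    # grammar FSM state this request resumes from (0 = initial state).
    parameters: dict
    fsm_grammar_state: int = 0


@dataclass
class SimpleChooser:
    # Stand-in for HeterogeneousNextTokenChooser: one entry per request,
    # so each grammar-constrained request can be advanced independently.
    parameters: List[dict]
    fsm_grammar_states: List[int]

    @classmethod
    def from_pb(cls, parameters: List[dict], fsm_grammar_states: List[int]):
        return cls(parameters=parameters, fsm_grammar_states=fsm_grammar_states)


def build_batch(requests: List[Request]) -> SimpleChooser:
    # Mirrors the diff: parallel lists are filled in one pass over the
    # requests, then handed to the chooser together.
    next_token_chooser_parameters = []
    fsm_grammar_states = []
    for r in requests:
        next_token_chooser_parameters.append(r.parameters)
        fsm_grammar_states.append(r.fsm_grammar_state)
    return SimpleChooser.from_pb(next_token_chooser_parameters, fsm_grammar_states)


if __name__ == "__main__":
    chooser = build_batch(
        [
            Request({"temperature": 1.0}),
            Request({"temperature": 0.7}, fsm_grammar_state=3),
        ]
    )
    print(chooser.fsm_grammar_states)  # [0, 3]
```

Keeping the grammar states in a list indexed the same way as the other per-request lists means the chooser can stay batched: requests with no grammar simply sit at the initial state while constrained requests track their own FSM position.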