diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 8b2206dd..a286e41c 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1033,6 +1033,7 @@ class FlashCausalLM(Model):
             cumulative_length += input_length
 
+        # Update values
         batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
             next_input_ids
         )
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index 1f633f8a..5ea2db87 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -815,6 +815,9 @@ class IdeficsCausalLM(Model):
             generations.append(generation)
 
             # Update values
+            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+                next_token_id_squeezed.item()
+            )
             batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 774b45c0..4585f4b9 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -694,6 +694,9 @@ class Mamba(Model):
             generations.append(generation)
 
             # Update values
+            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+                next_token_id_squeezed.item()
+            )
             batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index d7c074c0..459f4256 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -789,6 +789,9 @@ class Seq2SeqLM(Model):
             generations.append(generation)
 
             # Update values
+            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+                next_token_id_squeezed.item()
+            )
             batch.decoder_input_ids[i] = next_token_id
             batch.all_decoder_input_ids[i] = all_decoder_input_ids
             batch.input_lengths[i] = input_length
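
Note (not part of the patch): the diff suggests the non-flash models (IdeficsCausalLM, Mamba, Seq2SeqLM) were previously not advancing the per-request grammar state after emitting a token, so the next step's constrained sampling would use a stale state. Below is a minimal sketch of the pattern being applied; ToyGrammarChooser and its transition dict are hypothetical stand-ins, not TGI's actual NextTokenChooser, and only the advance_grammar name is taken from the diff above.

    class ToyGrammarChooser:
        def __init__(self, fsm_transitions, fsm_state=0):
            # fsm_transitions: assumed to map (state, token_id) -> next state
            self.fsm_transitions = fsm_transitions
            self.fsm_state = fsm_state

        def advance_grammar(self, next_token_id):
            # Advance the grammar FSM by the token just emitted, so the next
            # decoding step masks logits against the correct state.
            key = (self.fsm_state, next_token_id)
            self.fsm_state = self.fsm_transitions.get(key, self.fsm_state)
            return self

    # Per-request usage mirroring the idefics/mamba/seq2seq hunks:
    chooser = ToyGrammarChooser({(0, 42): 1, (1, 7): 2})
    chooser = chooser.advance_grammar(42)
    assert chooser.fsm_state == 1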