feat: advance grammars in all models

drbh 2024-02-13 00:33:36 +00:00
parent 8f14019053
commit 91a114a490
4 changed files with 10 additions and 0 deletions
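The change is the same in every model's decode path: right after a token is generated, the batch's next-token chooser is advanced so its grammar state reflects the token that was just emitted. The sketch below is a hypothetical illustration of what such an advance_grammar step could look like, assuming an FSM-style grammar; GrammarFSM, fsm_state, and the transition table are illustrative names, not the repository's actual implementation.

from dataclasses import dataclass
from typing import Dict, Set


@dataclass
class GrammarFSM:
    # transitions[state][token_id] -> next state; states with no entry are terminal
    transitions: Dict[int, Dict[int, int]]

    def allowed_tokens(self, state: int) -> Set[int]:
        return set(self.transitions.get(state, {}).keys())

    def next_state(self, state: int, token_id: int) -> int:
        return self.transitions.get(state, {}).get(token_id, state)


@dataclass
class NextTokenChooser:
    fsm: GrammarFSM
    fsm_state: int = 0

    def advance_grammar(self, next_token_id: int) -> "NextTokenChooser":
        # Step the FSM with the token that was just sampled and return the
        # chooser, matching the `chooser = chooser.advance_grammar(...)`
        # reassignment pattern used in the diffs below.
        self.fsm_state = self.fsm.next_state(self.fsm_state, next_token_id)
        return self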

@@ -1033,6 +1033,7 @@ class FlashCausalLM(Model):
cumulative_length += input_length
# Update values
batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
next_input_ids
)
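In this flash path the chooser is advanced once for the whole batch, taking the tensor of freshly sampled ids (next_input_ids), whereas the remaining models below advance one chooser per sequence. A minimal sketch of the batched variant, assuming one grammar state per batch row and reusing the GrammarFSM sketch above (variable names are illustrative):

from typing import List

import torch


def advance_grammar_batched(
    fsm: GrammarFSM, fsm_states: List[int], next_input_ids: torch.Tensor
) -> List[int]:
    # Step every row's grammar state with that row's freshly sampled token id.
    return [
        fsm.next_state(state, int(token_id))
        for state, token_id in zip(fsm_states, next_input_ids.tolist())
    ]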

@@ -815,6 +815,9 @@ class IdeficsCausalLM(Model):
generations.append(generation)
# Update values
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
next_token_id_squeezed.item()
)
batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length

@@ -694,6 +694,9 @@ class Mamba(Model):
generations.append(generation)
# Update values
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
next_token_id_squeezed.item()
)
batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length

@@ -789,6 +789,9 @@ class Seq2SeqLM(Model):
generations.append(generation)
# Update values
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
next_token_id_squeezed.item()
)
batch.decoder_input_ids[i] = next_token_id
batch.all_decoder_input_ids[i] = all_decoder_input_ids
batch.input_lengths[i] = input_length
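Advancing the chooser right after the token is recorded is what keeps the next decode step consistent: the following call can mask the logits against the updated grammar state, so Idefics, Mamba, and Seq2Seq generations stay grammar-constrained in the same way as the flash models. A rough sketch of that ordering, building on the NextTokenChooser sketch above (greedy selection stands in for the real sampling logic, and terminal states are not handled):

from typing import Sequence


def decode_step(chooser: NextTokenChooser, logits: Sequence[float]) -> int:
    # 1) restrict to tokens the current grammar state allows,
    # 2) pick a token (greedy here for brevity),
    # 3) advance the grammar so the next step sees the new state.
    allowed = chooser.fsm.allowed_tokens(chooser.fsm_state)
    next_token_id = max(allowed, key=lambda t: logits[t])
    chooser.advance_grammar(next_token_id)
    return next_token_id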