Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
feat: advance grammars in all models
commit 91a114a490
parent 8f14019053
@@ -1033,6 +1033,7 @@ class FlashCausalLM(Model):
            cumulative_length += input_length

        # Update values
        batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
            next_input_ids
        )
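The FlashCausalLM hunk advances the grammar state for the whole batch in one call once the new token ids are known. Below is a minimal sketch of what such a batched chooser could look like; the class name GrammarNextTokenChooser and the fsm.next_state call are illustrative assumptions, not the actual text-generation-inference implementation.

from typing import List, Sequence


class GrammarNextTokenChooser:
    """Hypothetical batched chooser: one grammar FSM state per sequence."""

    def __init__(self, fsm, fsm_grammar_states: List[int]):
        self.fsm = fsm
        self.fsm_grammar_states = fsm_grammar_states

    def advance_grammar(self, next_token_ids: Sequence[int]) -> "GrammarNextTokenChooser":
        # Move each sequence's FSM state forward by the token it just emitted,
        # so the next step only allows grammar-legal tokens.
        self.fsm_grammar_states = [
            self.fsm.next_state(state, int(token_id))
            for state, token_id in zip(self.fsm_grammar_states, next_token_ids)
        ]
        return self

Returning self is what lets the reassignment pattern in the hunk (batch.next_token_chooser = batch.next_token_chooser.advance_grammar(...)) work the same way whether the chooser mutates in place or builds a new object.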
@@ -815,6 +815,9 @@ class IdeficsCausalLM(Model):
            generations.append(generation)

            # Update values
            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
                next_token_id_squeezed.item()
            )
            batch.input_ids[i, 0] = next_token_id
            batch.all_input_ids[i] = all_input_ids
            batch.input_lengths[i] = new_input_length
@@ -694,6 +694,9 @@ class Mamba(Model):
            generations.append(generation)

            # Update values
            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
                next_token_id_squeezed.item()
            )
            batch.input_ids[i, 0] = next_token_id
            batch.all_input_ids[i] = all_input_ids
            batch.input_lengths[i] = new_input_length
@@ -789,6 +789,9 @@ class Seq2SeqLM(Model):
            generations.append(generation)

            # Update values
            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
                next_token_id_squeezed.item()
            )
            batch.decoder_input_ids[i] = next_token_id
            batch.all_decoder_input_ids[i] = all_decoder_input_ids
            batch.input_lengths[i] = input_length
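The IdeficsCausalLM, Mamba, and Seq2SeqLM hunks apply the same change per request: each request keeps its own chooser, and its grammar state is advanced with the single token id that was just sampled. A per-request sketch under the same assumptions (GrammarTokenChooser and fsm.next_state are hypothetical names, not the library's API):

class GrammarTokenChooser:
    """Hypothetical single-request chooser holding one grammar FSM state."""

    def __init__(self, fsm, fsm_grammar_state: int = 0):
        self.fsm = fsm
        self.fsm_grammar_state = fsm_grammar_state

    def advance_grammar(self, next_token_id: int) -> "GrammarTokenChooser":
        # Advance this request's FSM state with the token that was just sampled.
        self.fsm_grammar_state = self.fsm.next_state(self.fsm_grammar_state, next_token_id)
        return self


# Usage mirroring the hunks above (loop index i over the requests in a batch):
#     batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
#         next_token_id_squeezed.item()
#     )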