Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
feat: advance grammars in all models

commit 91a114a490 (parent 8f14019053)
@@ -1033,6 +1033,7 @@ class FlashCausalLM(Model):
             cumulative_length += input_length
 
+        # Update values
         batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
             next_input_ids
         )
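FlashCausalLM keeps a single batched next-token chooser, so its grammar state is advanced once per decode step with the whole vector of freshly sampled ids (`next_input_ids`). As a rough mental model only (the FSM and attribute names below are invented for illustration and are not the actual HeterogeneousNextTokenChooser internals), the chooser can be pictured as holding one grammar FSM state per sequence and stepping all of them together:

from typing import Dict, List, Tuple


class ToyGrammarFSM:
    """Hypothetical grammar automaton: maps (state, token_id) -> next state."""

    def __init__(self, transitions: Dict[Tuple[int, int], int]):
        self.transitions = transitions

    def next_state(self, state: int, token_id: int) -> int:
        # Unknown transitions fall into a terminal "dead" state (-1).
        return self.transitions.get((state, token_id), -1)


class ToyHeterogeneousNextTokenChooser:
    """Batched chooser: one grammar FSM and one state per sequence in the batch."""

    def __init__(self, fsms: List[ToyGrammarFSM], fsm_grammar_states: List[int]):
        self.fsms = fsms
        self.fsm_grammar_states = fsm_grammar_states

    def advance_grammar(self, next_token_ids: List[int]):
        # Step every sequence's FSM with the token it just produced, then return
        # self so the caller can reassign the chooser in one line, mirroring
        # batch.next_token_chooser = batch.next_token_chooser.advance_grammar(...)
        self.fsm_grammar_states = [
            fsm.next_state(state, token_id)
            for fsm, state, token_id in zip(
                self.fsms, self.fsm_grammar_states, next_token_ids
            )
        ]
        return self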
@@ -815,6 +815,9 @@ class IdeficsCausalLM(Model):
             generations.append(generation)
 
             # Update values
+            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+                next_token_id_squeezed.item()
+            )
             batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
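IdeficsCausalLM, like the remaining non-flash models below, keeps one chooser per request and advances it with a single token id inside the per-request "# Update values" block. A minimal per-request sketch, again with invented names and assuming the same toy FSM interface as above:

class ToyNextTokenChooser:
    """Per-request chooser: a single grammar FSM and its current state."""

    def __init__(self, fsm, fsm_grammar_state: int = 0):
        self.fsm = fsm  # any object exposing next_state(state, token_id)
        self.fsm_grammar_state = fsm_grammar_state

    def advance_grammar(self, next_token_id: int):
        # Move the grammar FSM forward by the token that was just sampled.
        self.fsm_grammar_state = self.fsm.next_state(
            self.fsm_grammar_state, next_token_id
        )
        # Returning the chooser is what makes the hunk's reassignment pattern
        # batch.next_token_choosers[i] = ...advance_grammar(next_token_id) work.
        return self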
@@ -694,6 +694,9 @@ class Mamba(Model):
             generations.append(generation)
 
             # Update values
+            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+                next_token_id_squeezed.item()
+            )
             batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
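Mamba gets the identical change. The call sits right after the generation is recorded because the advanced FSM state is what constrains the logits on the following step; without it, the next step would be masked against a stale grammar state. A toy decode loop makes that ordering explicit; `model_step`, `sample`, and `mask_logits_with_grammar` are placeholders, not TGI functions:

def toy_generate(
    model_step, sample, mask_logits_with_grammar, chooser, input_ids, max_new_tokens
):
    # chooser is assumed to follow the per-request sketch above.
    for _ in range(max_new_tokens):
        logits = model_step(input_ids)
        # Only tokens allowed from the *current* grammar state survive masking.
        logits = mask_logits_with_grammar(logits, chooser.fsm_grammar_state)
        next_token_id = sample(logits)
        # Mirror of the hunks: advance the grammar before the next iteration
        # and keep the returned chooser.
        chooser = chooser.advance_grammar(next_token_id)
        input_ids.append(next_token_id)
    return input_ids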
@@ -789,6 +789,9 @@ class Seq2SeqLM(Model):
             generations.append(generation)
 
             # Update values
+            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+                next_token_id_squeezed.item()
+            )
             batch.decoder_input_ids[i] = next_token_id
             batch.all_decoder_input_ids[i] = all_decoder_input_ids
             batch.input_lengths[i] = input_length
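Seq2SeqLM differs only in that the sampled token feeds `decoder_input_ids` rather than `input_ids`. One small detail shared by the three per-request hunks is the `.item()` call: the sampled id lives in a tensor, and `.item()` (standard PyTorch) converts it to a plain Python int before it reaches the grammar machinery, presumably so state transitions can be keyed on ints. A quick illustration:

import torch

# The batch stores the sampled id as a tensor; squeezing and calling .item()
# yields a plain Python int, which is what the hunks pass to advance_grammar.
next_token_id = torch.tensor([[42]])              # shape (1, 1)
next_token_id_squeezed = next_token_id.squeeze()  # 0-d tensor
token_as_int = next_token_id_squeezed.item()
assert isinstance(token_as_int, int)
print(token_as_int)  # 42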