Bit more simplification to flash_neox generate_tokens()

Nick Hill 2023-03-27 16:30:11 -07:00
parent 9895569c8b
commit f786d1ddf5


@@ -301,8 +301,6 @@ class FlashNeoX(Model):
         next_batch_cu_seqlens = [0]
         next_batch_past_key_values = []
         next_batch_input_lengths = []
-        next_batch_all_input_ids = []
-        next_batch_all_input_ids_tensor = []
 
         # Cumulative length
         cumulative_length = 0
@@ -368,8 +366,6 @@ class FlashNeoX(Model):
             next_batch_cu_seqlens[-1] + new_input_length
         )
         next_batch_input_lengths.append(new_input_length)
-        next_batch_all_input_ids.append(all_input_ids)
-        next_batch_all_input_ids_tensor.append(all_input_ids_tensor)
 
         # Prefill
         if prefill:
@@ -411,8 +407,6 @@ class FlashNeoX(Model):
         batch.max_seqlen += 1
         batch.past_key_values = next_batch_past_key_values
         batch.input_lengths = next_batch_input_lengths
-        batch.all_input_ids = next_batch_all_input_ids
-        batch.all_input_ids_tensor = next_batch_all_input_ids_tensor
 
         return generations
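
All three hunks drop the same pair of per-step bookkeeping lists: generate_tokens() no longer rebuilds next_batch_all_input_ids and next_batch_all_input_ids_tensor and reassigns them onto the batch. A plausible reading of why this is safe, sketched below with hypothetical names (the Batch class and function arguments are illustrative, not the repository's actual code): when each request's token list is appended to in place, batch.all_input_ids already references the updated objects, so collecting them into a fresh list and reassigning it every step is a no-op.

# A minimal sketch (illustrative names, not FlashNeoX's actual code) of
# why the dropped bookkeeping was redundant: each request's token list is
# mutated in place, so the batch sees the appended tokens without any
# rebuild-and-reassign step.

class Batch:
    def __init__(self, all_input_ids):
        # One growing list of token ids per request in the batch.
        self.all_input_ids = all_input_ids


def generate_tokens(batch, next_tokens):
    # Append this step's token to each request's list in place.
    # No next_batch_all_input_ids list needs to be collected, because
    # batch.all_input_ids holds references to these same list objects.
    for all_input_ids, token in zip(batch.all_input_ids, next_tokens):
        all_input_ids.append(token)


batch = Batch(all_input_ids=[[101, 7], [101]])
generate_tokens(batch, next_tokens=[42, 42])
assert batch.all_input_ids == [[101, 7, 42], [101, 42]]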