Some more simplification, fix flash_neox cu_seqlen pruning

Nick Hill 2023-03-30 18:10:18 -07:00
parent f786d1ddf5
commit 65ff6a73b3
3 changed files with 10 additions and 14 deletions

Changed file 1 of 3

@@ -281,7 +281,7 @@ class CausalLMBatch(Batch):
 #TODO maybe a single loop for all these list slices
 slice_list = itemgetter(*keep_indices) if new_size > 1 else lambda l: (l[keep_indices[0]],)
-batch.input_lengths = slice_list(batch.input_lengths)
+batch.input_lengths = list(slice_list(batch.input_lengths))
 batch.requests = slice_list(batch.requests)
 batch.all_input_ids = slice_list(batch.all_input_ids)
 batch.next_token_choosers = slice_list(batch.next_token_choosers)
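
Note: itemgetter(*keep_indices) applied to a list returns a tuple, which cannot be mutated in place. Wrapping the pruned input_lengths in list() keeps them mutable so the decode loop below can assign batch.input_lengths[i] = new_input_length directly. A minimal sketch of the distinction (the values are illustrative, not taken from the repo):

    from operator import itemgetter

    input_lengths = [5, 7, 9, 11]
    keep_indices = [0, 2, 3]

    pruned = itemgetter(*keep_indices)(input_lengths)   # (5, 9, 11) -- a tuple
    # pruned[0] = 6  # TypeError: 'tuple' object does not support item assignment

    input_lengths = list(itemgetter(*keep_indices)(input_lengths))
    input_lengths[0] = 6                                 # fine once it is a list
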
@@ -366,7 +366,6 @@ class CausalLM(Model):
 )
 # New values for next forward
-next_batch_input_lengths = []
 next_batch_input_ids = []
 next_batch_all_input_ids = []
@@ -405,7 +404,7 @@ class CausalLM(Model):
 next_batch_input_ids.append(next_token_id)
 next_batch_all_input_ids.append(all_input_ids)
-next_batch_input_lengths.append(new_input_length)
+batch.input_lengths[i] = new_input_length
 # Prefill
 if prefill:
@@ -437,7 +436,6 @@ class CausalLM(Model):
 batch.input_ids = torch.cat(next_batch_input_ids, dim=0)
 batch.past_key_values = past
 batch.all_input_ids = next_batch_all_input_ids
-batch.input_lengths = next_batch_input_lengths
 batch.max_input_length += 1
 batch.padding_right_offset -= 1

Changed file 2 of 3

@@ -209,7 +209,7 @@ class FlashNeoXBatch(Batch):
 #TODO maybe a single loop for all these list slices
 slice_list = itemgetter(*keep_indices) if new_size > 1 else lambda l: (l[keep_indices[0]],)
-batch.input_lengths = slice_list(batch.input_lengths)
+batch.input_lengths = list(slice_list(batch.input_lengths))
 batch.requests = slice_list(batch.requests)
 batch.all_input_ids = slice_list(batch.all_input_ids)
 batch.next_token_choosers = slice_list(batch.next_token_choosers)
@@ -221,7 +221,11 @@ class FlashNeoXBatch(Batch):
 batch.position_ids = batch.position_ids[keep_indices]
 batch.past_key_values = batch.past_key_values[:, keep_indices] \
     if batch.past_key_values is not None else None
-batch.cu_seqlens = batch.cu_seqlens[keep_indices]
+# Recalculate cumulative seq lengths
+new_cu_seqlens = batch.cu_seqlens.new_tensor(batch.input_lengths)
+torch.cumsum(new_cu_seqlens, dim=0, out=new_cu_seqlens)
+batch.cu_seqlens = torch.cat((batch.cu_seqlens[:1], new_cu_seqlens))
 return batch
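
Note: cu_seqlens holds the packed-sequence boundaries [0, len_0, len_0 + len_1, ...] and has batch_size + 1 entries, so indexing it with keep_indices (the removed line) keeps stale offsets from the old layout and leaves one entry too few. Recomputing it as a cumulative sum of the kept input_lengths, prefixed with the leading 0, yields valid boundaries for the pruned batch. A small sketch of the recomputation with illustrative values:

    import torch

    # original lengths [3, 5, 2, 4] -> cu_seqlens [0, 3, 8, 10, 14]
    cu_seqlens = torch.tensor([0, 3, 8, 10, 14], dtype=torch.int32)
    input_lengths = [3, 2, 4]        # lengths kept after dropping request 1

    # the old cu_seqlens[keep_indices] would give [0, 8, 10]: stale offsets
    new_cu_seqlens = cu_seqlens.new_tensor(input_lengths)      # [3, 2, 4]
    torch.cumsum(new_cu_seqlens, dim=0, out=new_cu_seqlens)    # [3, 5, 9]
    cu_seqlens = torch.cat((cu_seqlens[:1], new_cu_seqlens))   # [0, 3, 5, 9]
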
@@ -300,7 +304,6 @@ class FlashNeoX(Model):
 next_batch_position_ids = []
 next_batch_cu_seqlens = [0]
 next_batch_past_key_values = []
-next_batch_input_lengths = []
 # Cumulative length
 cumulative_length = 0
@@ -365,7 +368,7 @@ class FlashNeoX(Model):
 next_batch_cu_seqlens.append(
     next_batch_cu_seqlens[-1] + new_input_length
 )
-next_batch_input_lengths.append(new_input_length)
+batch.input_lengths[i] = new_input_length
 # Prefill
 if prefill:
@@ -406,7 +409,6 @@ class FlashNeoX(Model):
 batch.cu_seqlens = next_batch_cu_seqlens
 batch.max_seqlen += 1
 batch.past_key_values = next_batch_past_key_values
-batch.input_lengths = next_batch_input_lengths
 return generations

Changed file 3 of 3

@@ -444,7 +444,6 @@ class Seq2SeqLM(Model):
 # New values for next forward
 next_batch_decoder_input_ids = []
-next_batch_decoder_input_lengths = []
 # Results
 generations: List[Generation] = []
@@ -452,7 +451,6 @@ class Seq2SeqLM(Model):
 # Zipped iterator
 iterator = zip(
     batch.requests,
-    batch.decoder_input_lengths,
     logits,
     batch.next_token_choosers,
     batch.decoder_input_ids,
@@ -461,7 +459,6 @@ class Seq2SeqLM(Model):
 # For each member of the batch
 for i, (
     request,
-    decoder_input_length,
     logits,
     next_token_chooser,
     decoder_input_ids,
@@ -479,7 +476,7 @@ class Seq2SeqLM(Model):
 next_token_id_squeezed = next_token_id.squeeze()
 next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
-next_batch_decoder_input_lengths.append(decoder_input_length + 1)
+batch.decoder_input_lengths[i] += 1
 # Prefill
 if prefill:
@@ -508,7 +505,6 @@ class Seq2SeqLM(Model):
 batch.decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
 batch.encoder_last_hidden_state = encoder_last_hidden_state
 batch.past_key_values = past
-batch.decoder_input_lengths = next_batch_decoder_input_lengths
 batch.max_decoder_input_length += 1
 batch.padding_right_offset -= 1
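
Note: the same simplification appears in all three models. Instead of accumulating per-request lengths into a next_batch_* list during the decode loop and reassigning it afterwards, the loop now updates the already pruned (and now mutable) length list in place. A condensed before/after of that bookkeeping pattern, reduced to the length handling only:

    # before: rebuild the list every step
    next_batch_decoder_input_lengths = []
    for decoder_input_length in batch.decoder_input_lengths:
        next_batch_decoder_input_lengths.append(decoder_input_length + 1)
    batch.decoder_input_lengths = next_batch_decoder_input_lengths

    # after: bump each entry in place while iterating the batch
    for i in range(len(batch.decoder_input_lengths)):
        batch.decoder_input_lengths[i] += 1
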