From 65ff6a73b3d5a3bb0cafa3d667b0d077bfcfc3f5 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 30 Mar 2023 18:10:18 -0700
Subject: [PATCH] Some more simplification, fix flash_neox cu_seqlen pruning

---
 server/text_generation_server/models/causal_lm.py  |  6 ++----
 server/text_generation_server/models/flash_neox.py | 12 +++++++-----
 server/text_generation_server/models/seq2seq_lm.py |  6 +-----
 3 files changed, 10 insertions(+), 14 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 72c53647..a4822317 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -281,7 +281,7 @@ class CausalLMBatch(Batch):
 
         #TODO maybe a single loop for all these list slices
         slice_list = itemgetter(*keep_indices) if new_size > 1 else lambda l: (l[keep_indices[0]],)
-        batch.input_lengths = slice_list(batch.input_lengths)
+        batch.input_lengths = list(slice_list(batch.input_lengths))
         batch.requests = slice_list(batch.requests)
         batch.all_input_ids = slice_list(batch.all_input_ids)
         batch.next_token_choosers = slice_list(batch.next_token_choosers)
@@ -366,7 +366,6 @@ class CausalLM(Model):
         )
 
         # New values for next forward
-        next_batch_input_lengths = []
         next_batch_input_ids = []
         next_batch_all_input_ids = []
 
@@ -405,7 +404,7 @@ class CausalLM(Model):
 
             next_batch_input_ids.append(next_token_id)
             next_batch_all_input_ids.append(all_input_ids)
-            next_batch_input_lengths.append(new_input_length)
+            batch.input_lengths[i] = new_input_length
 
             # Prefill
             if prefill:
@@ -437,7 +436,6 @@ class CausalLM(Model):
         batch.input_ids = torch.cat(next_batch_input_ids, dim=0)
         batch.past_key_values = past
         batch.all_input_ids = next_batch_all_input_ids
-        batch.input_lengths = next_batch_input_lengths
         batch.max_input_length += 1
         batch.padding_right_offset -= 1
 
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 04a44d2f..dc901263 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -209,7 +209,7 @@ class FlashNeoXBatch(Batch):
 
         #TODO maybe a single loop for all these list slices
         slice_list = itemgetter(*keep_indices) if new_size > 1 else lambda l: (l[keep_indices[0]],)
-        batch.input_lengths = slice_list(batch.input_lengths)
+        batch.input_lengths = list(slice_list(batch.input_lengths))
         batch.requests = slice_list(batch.requests)
         batch.all_input_ids = slice_list(batch.all_input_ids)
         batch.next_token_choosers = slice_list(batch.next_token_choosers)
@@ -221,7 +221,11 @@ class FlashNeoXBatch(Batch):
         batch.position_ids = batch.position_ids[keep_indices]
         batch.past_key_values = batch.past_key_values[:, keep_indices] \
             if batch.past_key_values is not None else None
-        batch.cu_seqlens = batch.cu_seqlens[keep_indices]
+
+        # Recalculate cumulative seq lengths
+        new_cu_seqlens = batch.cu_seqlens.new_tensor(batch.input_lengths)
+        torch.cumsum(new_cu_seqlens, dim=0, out=new_cu_seqlens)
+        batch.cu_seqlens = torch.cat((batch.cu_seqlens[:1], new_cu_seqlens))
 
         return batch
 
@@ -300,7 +304,6 @@ class FlashNeoX(Model):
         next_batch_position_ids = []
         next_batch_cu_seqlens = [0]
         next_batch_past_key_values = []
-        next_batch_input_lengths = []
 
         # Cumulative length
         cumulative_length = 0
@@ -365,7 +368,7 @@ class FlashNeoX(Model):
             next_batch_cu_seqlens.append(
                 next_batch_cu_seqlens[-1] + new_input_length
            )
-            next_batch_input_lengths.append(new_input_length)
+            batch.input_lengths[i] = new_input_length
 
             # Prefill
             if prefill:
@@ -406,7 +409,6 @@ class FlashNeoX(Model):
         batch.cu_seqlens = next_batch_cu_seqlens
         batch.max_seqlen += 1
         batch.past_key_values = next_batch_past_key_values
-        batch.input_lengths = next_batch_input_lengths
 
         return generations
 
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 27c7d538..5d4b0386 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -444,7 +444,6 @@ class Seq2SeqLM(Model):
 
         # New values for next forward
         next_batch_decoder_input_ids = []
-        next_batch_decoder_input_lengths = []
 
         # Results
         generations: List[Generation] = []
@@ -452,7 +451,6 @@ class Seq2SeqLM(Model):
         # Zipped iterator
         iterator = zip(
             batch.requests,
-            batch.decoder_input_lengths,
             logits,
             batch.next_token_choosers,
             batch.decoder_input_ids,
@@ -461,7 +459,6 @@ class Seq2SeqLM(Model):
         # For each member of the batch
         for i, (
             request,
-            decoder_input_length,
             logits,
             next_token_chooser,
             decoder_input_ids,
@@ -479,7 +476,7 @@ class Seq2SeqLM(Model):
             next_token_id_squeezed = next_token_id.squeeze()
 
             next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
-            next_batch_decoder_input_lengths.append(decoder_input_length + 1)
+            batch.decoder_input_lengths[i] += 1
 
             # Prefill
             if prefill:
@@ -508,7 +505,6 @@ class Seq2SeqLM(Model):
         batch.decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
         batch.encoder_last_hidden_state = encoder_last_hidden_state
         batch.past_key_values = past
-        batch.decoder_input_lengths = next_batch_decoder_input_lengths
         batch.max_decoder_input_length += 1
         batch.padding_right_offset -= 1
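For reference, a minimal standalone sketch of the cu_seqlens rebuild that the flash_neox.py
pruning hunk above introduces. Because cu_seqlens holds cumulative offsets, selecting entries
by keep_indices (the old behaviour) leaves stale gaps from the dropped sequences, so the fix
recomputes the running sums from the surviving input_lengths. The sequence lengths,
keep_indices, and printed values below are hypothetical illustration data; only the
new_tensor / cumsum / cat logic mirrors the patch.

    import torch

    # Hypothetical pre-pruning state: four sequences with lengths 5, 3, 7, 2.
    input_lengths = [5, 3, 7, 2]
    cu_seqlens = torch.tensor([0, 5, 8, 15, 17], dtype=torch.int32)

    # Suppose only requests 0 and 2 survive pruning.
    keep_indices = [0, 2]
    input_lengths = [input_lengths[i] for i in keep_indices]  # [5, 7]

    # Indexing cu_seqlens by keep_indices would give [0, 8], which no longer
    # matches the kept lengths. Rebuild the cumulative sums instead.
    new_cu_seqlens = cu_seqlens.new_tensor(input_lengths)
    torch.cumsum(new_cu_seqlens, dim=0, out=new_cu_seqlens)
    cu_seqlens = torch.cat((cu_seqlens[:1], new_cu_seqlens))

    print(cu_seqlens)  # tensor([ 0,  5, 12], dtype=torch.int32)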