From f2f78e17d149a10a1c23dfde02d40738f4dc22c0 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Sat, 18 Feb 2023 17:30:45 +0100
Subject: [PATCH] better implem

---
 server/text_generation/models/causal_lm.py  | 48 ++++++-------------
 server/text_generation/models/galactica.py  |  2 -
 server/text_generation/models/seq2seq_lm.py | 53 +++++++--------------
 3 files changed, 31 insertions(+), 72 deletions(-)

diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 040d7a6a..e109b83b 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -37,7 +37,6 @@ class CausalLMBatch(Batch):
     # Metadata used for padding
     size: int
     max_sequence_length: int
-    max_potential_length: int
     padding_right_offset: int
 
     # Past metadata
@@ -64,7 +63,6 @@ class CausalLMBatch(Batch):
 
         # Parse batch
         max_sequence_length = 0
-        max_potential_length = 0
        padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
@@ -75,10 +73,9 @@ class CausalLMBatch(Batch):
             )
             stopping_criterias.append(stopping_criteria)
             max_sequence_length = max(max_sequence_length, r.input_length)
-            potential_length = r.input_length + stopping_criteria.max_new_tokens
-            if max_potential_length < potential_length:
-                max_potential_length = potential_length
-            padding_right_offset = stopping_criteria.max_new_tokens
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
+            )
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
@@ -89,7 +86,9 @@ class CausalLMBatch(Batch):
 
         input_ids = tokenized_inputs["input_ids"]
         # Allocate maximum attention_mask
-        attention_mask = input_ids.new_zeros((pb.size, max_potential_length))
+        attention_mask = input_ids.new_zeros(
+            (pb.size, max_sequence_length + padding_right_offset)
+        )
         # Copy tokenizer attention_mask into fully allocated attention_mask
         attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
 
@@ -110,7 +109,6 @@ class CausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             size=pb.size,
             max_sequence_length=max_sequence_length,
-            max_potential_length=max_potential_length,
             padding_right_offset=padding_right_offset,
         )
 
@@ -120,14 +118,11 @@ class CausalLMBatch(Batch):
         # Used for padding
         total_batch_size = 0
         max_sequence_length = 0
-        max_potential_length = 0
         padding_right_offset = 0
         for batch in batches:
             total_batch_size += batch.size
             max_sequence_length = max(max_sequence_length, batch.max_sequence_length)
-            if max_potential_length < batch.max_potential_length:
-                max_potential_length = batch.max_potential_length
-            padding_right_offset = batch.padding_right_offset
+            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
 
         # Batch attributes
         requests = []
@@ -170,21 +165,21 @@ class CausalLMBatch(Batch):
             # Create padded tensor
             if attention_mask is None:
                 attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_potential_length),
+                    (total_batch_size, max_sequence_length + padding_right_offset),
                 )
 
             # We need to slice the attention mask to remove padding from previous steps
             # and to remove unused allocated space
+            left_offset = max_sequence_length - batch.max_sequence_length
+            batch_left_offset = (
+                batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset
+            )
             attention_mask[
                 start_index:end_index,
-                -(
-                    batch.max_sequence_length + padding_right_offset
-                ) : -padding_right_offset,
+                left_offset:-padding_right_offset,
             ] = batch.attention_mask[
                 :,
-                -(
-                    batch.max_sequence_length + batch.padding_right_offset
-                ) : -batch.padding_right_offset,
+                batch_left_offset : -batch.padding_right_offset,
             ]
 
             # Create empty tensor
@@ -263,7 +258,6 @@ class CausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
             max_sequence_length=max_sequence_length,
-            max_potential_length=max_potential_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
         )
@@ -332,10 +326,7 @@ class CausalLM(Model):
         self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
         # slice the attention mask to the correct shape
-        if batch.padding_right_offset != 0:
-            attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-        else:
-            attention_mask = batch.attention_mask
+        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
 
         logits, past = self.forward(
             batch.input_ids,
@@ -355,7 +346,6 @@ class CausalLM(Model):
         # Metadata
         next_batch_size = 0
         next_batch_max_sequence_length = 0
-        next_batch_max_potential_length = 0
 
         # Results
         generations: List[Generation] = []
@@ -428,13 +418,6 @@ class CausalLM(Model):
                 next_batch_max_sequence_length = max(
                     next_batch_max_sequence_length, new_input_length
                 )
-                # potential length is input_length + max_new_tokens but we need to remove generated tokens
-                next_batch_max_potential_length = max(
-                    next_batch_max_potential_length,
-                    new_input_length
-                    + stopping_criteria.max_new_tokens
-                    - stopping_criteria.current_tokens,
-                )
 
             # Prefill
             if stopping_criteria.current_tokens == 1:
@@ -518,7 +501,6 @@ class CausalLM(Model):
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
             max_sequence_length=next_batch_max_sequence_length,
-            max_potential_length=next_batch_max_potential_length,
             padding_right_offset=batch.padding_right_offset - 1,
             keys_head_dim_last=batch.keys_head_dim_last,
         )
diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py
index 780a94f1..e4b861c8 100644
--- a/server/text_generation/models/galactica.py
+++ b/server/text_generation/models/galactica.py
@@ -106,12 +106,10 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             )
 
         # Tokenize batch
-        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=pad_to_multiple_of,
             return_token_type_ids=False,
         ).to(device)
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index 837fd0d1..4813764b 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -42,7 +42,6 @@ class Seq2SeqLMBatch(Batch):
     size: int
     max_input_length: int
     max_decoder_input_length: int
-    max_potential_length: int
     padding_right_offset: int
 
     def to_pb(self) -> generate_pb2.Batch:
@@ -71,7 +70,6 @@ class Seq2SeqLMBatch(Batch):
 
         # Parse batch
         max_input_length = 0
-        max_potential_length = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
@@ -85,18 +83,15 @@ class Seq2SeqLMBatch(Batch):
             )
             stopping_criterias.append(stopping_criteria)
             max_input_length = max(max_input_length, r.input_length)
-            if max_potential_length < stopping_criteria.max_new_tokens + 1:
-                # +1 because we have the bos token
-                max_potential_length = stopping_criteria.max_new_tokens + 1
-            padding_right_offset = stopping_criteria.max_new_tokens
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
+            )
 
         # Tokenize batch
-        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=pad_to_multiple_of,
             return_token_type_ids=False,
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
@@ -118,7 +113,6 @@ class Seq2SeqLMBatch(Batch):
             size=len(pb.requests),
             max_input_length=max(input_lengths),
             max_decoder_input_length=1,
-            max_potential_length=max_potential_length,
             padding_right_offset=padding_right_offset,
         )
 
@@ -131,7 +125,6 @@ class Seq2SeqLMBatch(Batch):
         total_batch_size = 0
         max_input_length = 0
         max_decoder_input_length = 0
-        max_potential_length = 0
         padding_right_offset = 0
         for batch in batches:
             total_batch_size += batch.size
@@ -139,9 +132,7 @@ class Seq2SeqLMBatch(Batch):
             max_decoder_input_length = max(
                 max_decoder_input_length, batch.max_decoder_input_length
             )
-            if max_potential_length < batch.max_potential_length:
-                max_potential_length = batch.max_potential_length
-            padding_right_offset = batch.padding_right_offset
+            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
 
         # Batch attributes
         requests = []
@@ -200,29 +191,28 @@ class Seq2SeqLMBatch(Batch):
            if decoder_attention_mask is None:
                 # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                 decoder_attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_potential_length),
+                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                 )
 
             # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
             # this batch. All generations are of length `batch.max_decoder_input_length`.
+            left_offset = max_decoder_input_length - batch.max_decoder_input_length
             if batch.decoder_attention_mask is None:
                 decoder_attention_mask[
                     start_index:end_index,
-                    -(
-                        batch.max_decoder_input_length + padding_right_offset
-                    ) : -padding_right_offset,
+                    left_offset:-padding_right_offset,
                 ] = 1
             # If it exists, we need to index
             else:
+                batch_left_offset = (
+                    batch.decoder_attention_mask.shape[1]
+                    - batch.max_decoder_input_length - batch.padding_right_offset
+                )
                 decoder_attention_mask[
                     start_index:end_index,
-                    -(
-                        batch.max_decoder_input_length + padding_right_offset
-                    ) : -padding_right_offset,
+                    left_offset:-padding_right_offset,
                 ] = batch.decoder_attention_mask[
                     :,
-                    -(
-                        batch.max_decoder_input_length + batch.padding_right_offset
-                    ) : -batch.padding_right_offset,
+                    batch_left_offset : -batch.padding_right_offset,
                 ]
 
             # Create padded tensor
@@ -308,7 +298,6 @@ class Seq2SeqLMBatch(Batch):
             size=total_batch_size,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
-            max_potential_length=max_potential_length,
             padding_right_offset=padding_right_offset,
         )
 
@@ -387,12 +376,9 @@ class Seq2SeqLM(Model):
     ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
         if batch.decoder_attention_mask is not None:
             # slice to the correct shape
-            if batch.padding_right_offset != 0:
-                decoder_attention_mask = batch.decoder_attention_mask[
-                    :, : -batch.padding_right_offset
-                ]
-            else:
-                decoder_attention_mask = batch.decoder_attention_mask
+            decoder_attention_mask = batch.decoder_attention_mask[
+                :, : -batch.padding_right_offset
+            ]
         else:
             decoder_attention_mask = None
 
@@ -431,7 +417,6 @@ class Seq2SeqLM(Model):
         next_batch_size = 0
         next_batch_max_input_length = 0
         next_batch_max_decoder_input_length = 0
-        next_batch_max_potential_length = 0
 
         # Finished requests
         generations: List[Generation] = []
@@ -506,11 +491,6 @@ class Seq2SeqLM(Model):
                 next_batch_max_decoder_input_length = max(
                     next_batch_max_decoder_input_length, new_decoder_input_length
                 )
-                # +1 because of the bos token
-                next_batch_max_potential_length = max(
-                    next_batch_max_potential_length,
-                    stopping_criteria.max_new_tokens + 1,
-                )
 
             # Prefill
             if stopping_criteria.current_tokens == 1:
@@ -598,7 +578,6 @@ class Seq2SeqLM(Model):
             size=next_batch_size,
             max_input_length=next_batch_max_input_length,
             max_decoder_input_length=next_batch_max_decoder_input_length,
-            max_potential_length=next_batch_max_potential_length,
             padding_right_offset=batch.padding_right_offset - 1,
         )
         return generations, next_batch
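Note (reviewer illustration, not part of the patch): the change sizes `attention_mask` once per batch as `max_sequence_length + padding_right_offset` columns, i.e. the longest prompt plus the largest `max_new_tokens` in the batch, instead of the removed `max_potential_length` (the largest `input_length + max_new_tokens` of any single request). Each decoding step then just slices the unused right padding off instead of reallocating the mask. A minimal standalone sketch of that pattern with made-up sizes (plain PyTorch, none of the repo's types):

    import torch

    # Assumed example values, not taken from the patch:
    batch_size = 2
    max_sequence_length = 5    # longest prompt in the batch, left-padded
    padding_right_offset = 3   # largest max_new_tokens among the requests

    # Allocate the mask once: prompt region + room for every future token.
    attention_mask = torch.zeros((batch_size, max_sequence_length + padding_right_offset))

    # The tokenizer mask only covers the prompt region.
    attention_mask[:, :max_sequence_length] = torch.tensor([[0, 0, 1, 1, 1],
                                                            [1, 1, 1, 1, 1]])

    # One decoding step: attend only to the columns that are valid so far ...
    step_mask = attention_mask[:, :-padding_right_offset]

    # ... then mark the newly generated token as valid and shrink the offset,
    # mirroring how the patch hands padding_right_offset - 1 to the next batch.
    attention_mask[:, -padding_right_offset] = 1
    padding_right_offset -= 1

This is presumably also why the old `padding_right_offset != 0` guard in `generate_token` could be dropped: the offset starts at the batch's largest `max_new_tokens` and the batch is no longer scheduled once that budget is exhausted, so the slice end stays negative.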
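Similarly for `CausalLMBatch.concatenate` (and the decoder mask in `Seq2SeqLMBatch.concatenate`): each source batch carries its own left padding and its own unused right allocation, so instead of the old negative-index arithmetic the patch computes `left_offset` (where the source's valid region starts inside the wider merged mask) and `batch_left_offset` (where it starts inside the source's own, possibly over-allocated, mask). A rough worked example with assumed shapes:

    import torch

    # Assumed source batch: its mask was allocated as
    # (size, batch_max_sequence_length + batch_padding_right_offset) but may have
    # extra unused columns on the left, e.g. left behind by longer requests that
    # already finished.
    batch_max_sequence_length = 4
    batch_padding_right_offset = 2
    src_mask = torch.zeros((1, 10))
    src_mask[:, -(batch_max_sequence_length + batch_padding_right_offset)
               : -batch_padding_right_offset] = 1   # the 4 valid columns

    # Assumed merged batch: wider on both sides.
    max_sequence_length = 6     # longest sequence across all merged batches
    padding_right_offset = 3    # largest remaining max_new_tokens
    merged_mask = torch.zeros((1, max_sequence_length + padding_right_offset))

    # Destination offset: extra left padding this batch needs in the merged mask.
    left_offset = max_sequence_length - batch_max_sequence_length
    # Source offset: skip whatever left over-allocation the source mask carries.
    batch_left_offset = (
        src_mask.shape[1] - batch_max_sequence_length - batch_padding_right_offset
    )

    # Both slices select exactly batch_max_sequence_length columns (4 here), so
    # the copy lines the valid regions up without any reallocation.
    merged_mask[:, left_offset:-padding_right_offset] = src_mask[
        :, batch_left_offset:-batch_padding_right_offset
    ]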