better implem

OlivierDehaene 2023-02-18 17:30:45 +01:00
parent a8446a5a31
commit f2f78e17d1
3 changed files with 31 additions and 72 deletions

View File

@@ -37,7 +37,6 @@ class CausalLMBatch(Batch):
     # Metadata used for padding
     size: int
     max_sequence_length: int
-    max_potential_length: int
     padding_right_offset: int

     # Past metadata
@@ -64,7 +63,6 @@ class CausalLMBatch(Batch):

         # Parse batch
         max_sequence_length = 0
-        max_potential_length = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
@@ -75,10 +73,9 @@ class CausalLMBatch(Batch):
             )
             stopping_criterias.append(stopping_criteria)
             max_sequence_length = max(max_sequence_length, r.input_length)
-            potential_length = r.input_length + stopping_criteria.max_new_tokens
-            if max_potential_length < potential_length:
-                max_potential_length = potential_length
-            padding_right_offset = stopping_criteria.max_new_tokens
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
+            )

         tokenized_inputs = tokenizer(
             inputs,
@@ -89,7 +86,9 @@ class CausalLMBatch(Batch):
         input_ids = tokenized_inputs["input_ids"]

         # Allocate maximum attention_mask
-        attention_mask = input_ids.new_zeros((pb.size, max_potential_length))
+        attention_mask = input_ids.new_zeros(
+            (pb.size, max_sequence_length + padding_right_offset)
+        )

         # Copy tokenizer attention_mask into fully allocated attention_mask
         attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
@@ -110,7 +109,6 @@ class CausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             size=pb.size,
             max_sequence_length=max_sequence_length,
-            max_potential_length=max_potential_length,
             padding_right_offset=padding_right_offset,
         )
@@ -120,14 +118,11 @@ class CausalLMBatch(Batch):
         # Used for padding
         total_batch_size = 0
         max_sequence_length = 0
-        max_potential_length = 0
         padding_right_offset = 0
         for batch in batches:
             total_batch_size += batch.size
             max_sequence_length = max(max_sequence_length, batch.max_sequence_length)
-            if max_potential_length < batch.max_potential_length:
-                max_potential_length = batch.max_potential_length
-            padding_right_offset = batch.padding_right_offset
+            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

         # Batch attributes
         requests = []
@@ -170,21 +165,21 @@ class CausalLMBatch(Batch):
             # Create padded tensor
             if attention_mask is None:
                 attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_potential_length),
+                    (total_batch_size, max_sequence_length + padding_right_offset),
                 )

             # We need to slice the attention mask to remove padding from previous steps
+            # and to remove unused allocated space
+            left_offset = max_sequence_length - batch.max_sequence_length
+            batch_left_offset = (
+                batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset
+            )
             attention_mask[
                 start_index:end_index,
-                -(
-                    batch.max_sequence_length + padding_right_offset
-                ) : -padding_right_offset,
+                left_offset:-padding_right_offset,
             ] = batch.attention_mask[
                 :,
-                -(
-                    batch.max_sequence_length + batch.padding_right_offset
-                ) : -batch.padding_right_offset,
+                batch_left_offset : -batch.padding_right_offset,
             ]

             # Create empty tensor
@@ -263,7 +258,6 @@ class CausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
             max_sequence_length=max_sequence_length,
-            max_potential_length=max_potential_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
         )
@@ -332,10 +326,7 @@ class CausalLM(Model):
         self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
         # slice the attention mask to the correct shape
-        if batch.padding_right_offset != 0:
-            attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-        else:
-            attention_mask = batch.attention_mask
+        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

         logits, past = self.forward(
             batch.input_ids,
@@ -355,7 +346,6 @@ class CausalLM(Model):
         # Metadata
         next_batch_size = 0
         next_batch_max_sequence_length = 0
-        next_batch_max_potential_length = 0

         # Results
         generations: List[Generation] = []
@@ -428,13 +418,6 @@ class CausalLM(Model):
                 next_batch_max_sequence_length = max(
                     next_batch_max_sequence_length, new_input_length
                 )
-                # potential length is input_length + max_new_tokens but we need to remove generated tokens
-                next_batch_max_potential_length = max(
-                    next_batch_max_potential_length,
-                    new_input_length
-                    + stopping_criteria.max_new_tokens
-                    - stopping_criteria.current_tokens,
-                )

             # Prefill
             if stopping_criteria.current_tokens == 1:
@@ -518,7 +501,6 @@ class CausalLM(Model):
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
             max_sequence_length=next_batch_max_sequence_length,
-            max_potential_length=next_batch_max_potential_length,
             padding_right_offset=batch.padding_right_offset - 1,
             keys_head_dim_last=batch.keys_head_dim_last,
         )
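The attention mask is now allocated once with `max_sequence_length + padding_right_offset` columns, and `concatenate` copies each batch's useful window into place using a `left_offset` into the merged mask and a `batch_left_offset` into the source mask. A minimal standalone sketch of that indexing arithmetic, with made-up sizes and dummy masks (this is not the repository code):

import torch

# Dummy per-batch masks laid out as [left pad | tokens seen so far | reserved right pad];
# only the indexing arithmetic matters here, the values are placeholders.
batch_masks = [
    torch.ones((2, 5 + 3)),  # max_sequence_length=5, padding_right_offset=3
    torch.ones((1, 7 + 2)),  # max_sequence_length=7, padding_right_offset=2
]
batch_max_lengths = [5, 7]
batch_right_offsets = [3, 2]

total_batch_size = sum(mask.shape[0] for mask in batch_masks)
max_sequence_length = max(batch_max_lengths)     # 7
padding_right_offset = max(batch_right_offsets)  # 3

# Allocate once for the worst case: longest prompt + largest remaining generation budget.
attention_mask = torch.zeros(
    (total_batch_size, max_sequence_length + padding_right_offset)
)

start_index = 0
for mask, length, right_offset in zip(batch_masks, batch_max_lengths, batch_right_offsets):
    end_index = start_index + mask.shape[0]
    # Skip left padding in the target and unused allocated columns in the source.
    left_offset = max_sequence_length - length
    batch_left_offset = mask.shape[1] - length - right_offset
    attention_mask[
        start_index:end_index, left_offset:-padding_right_offset
    ] = mask[:, batch_left_offset:-right_offset]
    start_index = end_index

The seq2seq changes further down apply the same pattern to `decoder_attention_mask`, with `max_decoder_input_length` in place of `max_sequence_length`.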

View File

@@ -106,12 +106,10 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             )

         # Tokenize batch
-        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=pad_to_multiple_of,
             return_token_type_ids=False,
         ).to(device)
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1

View File

@@ -42,7 +42,6 @@ class Seq2SeqLMBatch(Batch):
     size: int
     max_input_length: int
     max_decoder_input_length: int
-    max_potential_length: int
     padding_right_offset: int

     def to_pb(self) -> generate_pb2.Batch:
@@ -71,7 +70,6 @@ class Seq2SeqLMBatch(Batch):

         # Parse batch
         max_input_length = 0
-        max_potential_length = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
@@ -85,18 +83,15 @@ class Seq2SeqLMBatch(Batch):
             )
             stopping_criterias.append(stopping_criteria)
             max_input_length = max(max_input_length, r.input_length)
-            if max_potential_length < stopping_criteria.max_new_tokens + 1:
-                # +1 because we have the bos token
-                max_potential_length = stopping_criteria.max_new_tokens + 1
-            padding_right_offset = stopping_criteria.max_new_tokens
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
+            )

         # Tokenize batch
-        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=pad_to_multiple_of,
             return_token_type_ids=False,
         ).to(device)

         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
@@ -118,7 +113,6 @@ class Seq2SeqLMBatch(Batch):
             size=len(pb.requests),
             max_input_length=max(input_lengths),
             max_decoder_input_length=1,
-            max_potential_length=max_potential_length,
             padding_right_offset=padding_right_offset,
         )
@@ -131,7 +125,6 @@ class Seq2SeqLMBatch(Batch):
         total_batch_size = 0
         max_input_length = 0
         max_decoder_input_length = 0
-        max_potential_length = 0
         padding_right_offset = 0
         for batch in batches:
             total_batch_size += batch.size
@@ -139,9 +132,7 @@ class Seq2SeqLMBatch(Batch):
             max_decoder_input_length = max(
                 max_decoder_input_length, batch.max_decoder_input_length
             )
-            if max_potential_length < batch.max_potential_length:
-                max_potential_length = batch.max_potential_length
-            padding_right_offset = batch.padding_right_offset
+            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

         # Batch attributes
         requests = []
@@ -200,29 +191,28 @@ class Seq2SeqLMBatch(Batch):
             if decoder_attention_mask is None:
                 # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                 decoder_attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_potential_length),
+                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                 )

             # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
             # this batch. All generations are of length `batch.max_decoder_input_length`.
+            left_offset = max_decoder_input_length - batch.max_decoder_input_length
             if batch.decoder_attention_mask is None:
                 decoder_attention_mask[
                     start_index:end_index,
-                    -(
-                        batch.max_decoder_input_length + padding_right_offset
-                    ) : -padding_right_offset,
+                    left_offset:-padding_right_offset,
                 ] = 1
             # If it exists, we need to index
             else:
+                batch_left_offset = (
+                    batch.decoder_attention_mask.shape[1]
+                    - batch.max_decoder_input_length - batch.padding_right_offset
+                )
                 decoder_attention_mask[
                     start_index:end_index,
-                    -(
-                        batch.max_decoder_input_length + padding_right_offset
-                    ) : -padding_right_offset,
+                    left_offset:-padding_right_offset,
                 ] = batch.decoder_attention_mask[
                     :,
-                    -(
-                        batch.max_decoder_input_length + batch.padding_right_offset
-                    ) : -batch.padding_right_offset,
+                    batch_left_offset : -batch.padding_right_offset,
                 ]

             # Create padded tensor
@@ -308,7 +298,6 @@ class Seq2SeqLMBatch(Batch):
             size=total_batch_size,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
-            max_potential_length=max_potential_length,
             padding_right_offset=padding_right_offset,
         )
@@ -387,12 +376,9 @@ class Seq2SeqLM(Model):
     ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
         if batch.decoder_attention_mask is not None:
             # slice to the correct shape
-            if batch.padding_right_offset != 0:
-                decoder_attention_mask = batch.decoder_attention_mask[
-                    :, : -batch.padding_right_offset
-                ]
-            else:
-                decoder_attention_mask = batch.decoder_attention_mask
+            decoder_attention_mask = batch.decoder_attention_mask[
+                :, : -batch.padding_right_offset
+            ]
         else:
             decoder_attention_mask = None
@@ -431,7 +417,6 @@ class Seq2SeqLM(Model):
         next_batch_size = 0
         next_batch_max_input_length = 0
         next_batch_max_decoder_input_length = 0
-        next_batch_max_potential_length = 0

         # Finished requests
         generations: List[Generation] = []
@@ -506,11 +491,6 @@ class Seq2SeqLM(Model):
                 next_batch_max_decoder_input_length = max(
                     next_batch_max_decoder_input_length, new_decoder_input_length
                 )
-                # +1 because of the bos token
-                next_batch_max_potential_length = max(
-                    next_batch_max_potential_length,
-                    stopping_criteria.max_new_tokens + 1,
-                )

             # Prefill
             if stopping_criteria.current_tokens == 1:
@@ -598,7 +578,6 @@ class Seq2SeqLM(Model):
             size=next_batch_size,
             max_input_length=next_batch_max_input_length,
             max_decoder_input_length=next_batch_max_decoder_input_length,
-            max_potential_length=next_batch_max_potential_length,
             padding_right_offset=batch.padding_right_offset - 1,
         )
         return generations, next_batch
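A similar made-up sketch (again not the repository code) of the per-step bookkeeping in `generate_token`: each step slices away the still-unused right padding, marks the newly generated token, and hands the next batch a `padding_right_offset` reduced by one. Dropping the old `padding_right_offset != 0` guards presumably relies on the invariant that a batch with unfinished requests always has at least one reserved column left, since the request with the largest `max_new_tokens` finishes on the step where the offset would otherwise reach zero.

import torch

# One request: 2 prompt tokens already filled, 3 columns reserved for future tokens.
padding_right_offset = 3
attention_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0, 0.0]])

for _ in range(3):
    # Only the already-meaningful columns are fed to the forward pass.
    step_mask = attention_mask[:, :-padding_right_offset]
    print(step_mask.shape)  # torch.Size([1, 2]), then [1, 3], then [1, 4]
    # ... model forward with step_mask would happen here ...
    # The token generated at this step becomes visible for the next step.
    attention_mask[:, -padding_right_offset] = 1
    padding_right_offset -= 1  # mirrors padding_right_offset=batch.padding_right_offset - 1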