OlivierDehaene 2023-04-24 16:08:50 +02:00
parent 885411e747
commit 889897fe69
4 changed files with 46 additions and 36 deletions

View File

@@ -181,9 +181,7 @@ def test_causal_lm_generate_token_completion_multi(
     next_batch = next_batch.filter([next_batch.requests[0]])
 
     for _ in range(
-        stopping_criterias[0].max_new_tokens
-        - stopping_criterias[1].max_new_tokens
-        - 1
+        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
         generations, next_batch = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)
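
For orientation, the loop bound being reflowed here can be read with made-up token budgets (the real values come from the test fixtures, not from these numbers):

    # Hypothetical budgets: request 0 may generate 20 new tokens, request 1 only 5.
    max_new_tokens_0, max_new_tokens_1 = 20, 5
    # Request 1 completes first; after it is filtered out, request 0 still needs
    # 20 - 5 tokens. The test loops one step short of that so the final
    # generate_token() call can be asserted on separately.
    remaining_steps = max_new_tokens_0 - max_new_tokens_1 - 1
    assert remaining_steps == 14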

View File

@@ -174,14 +174,14 @@ def test_causal_lm_generate_token_completion_multi(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
     # Copy stopping_criterias before filtering
-    stopping_criterias = default_multi_requests_causal_lm_batch.stopping_criterias.copy()
+    stopping_criterias = (
+        default_multi_requests_causal_lm_batch.stopping_criterias.copy()
+    )
 
     next_batch = next_batch.filter([next_batch.requests[0]])
 
     for _ in range(
-        stopping_criterias[0].max_new_tokens
-        - stopping_criterias[1].max_new_tokens
-        - 1
+        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
         generations, next_batch = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)
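
The `# Copy stopping_criterias before filtering` comment is the point of this assignment: `filter()` rebuilds the batch's per-request lists, so the test keeps its own shallow copy to compute the loop bound afterwards. A toy illustration with plain-list stand-ins (not the real StoppingCriteria objects):

    criterias = ["criteria_0", "criteria_1"]   # stand-ins for two StoppingCriteria
    kept_copy = criterias.copy()               # what the test holds on to
    criterias = [criterias[0]]                 # roughly what filtering leaves behind
    assert len(kept_copy) == 2 and len(criterias) == 1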

View File

@@ -178,11 +178,12 @@ class CausalLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
-            remaining_decode_tokens = stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            remaining_decode_tokens = (
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            )
             total_remaining_decode_tokens += remaining_decode_tokens
             new_padding_right_offset = max(
-                new_padding_right_offset,
-                remaining_decode_tokens
+                new_padding_right_offset, remaining_decode_tokens
             )
 
         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
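
A sketch, with invented numbers, of what the two accumulators in this hunk track:

    # Two kept requests with different remaining budgets.
    max_new_tokens = [20, 12]
    current_tokens = [5, 5]
    remaining = [m - c for m, c in zip(max_new_tokens, current_tokens)]  # [15, 7]
    total_remaining_decode_tokens = sum(remaining)                       # 22
    # The filtered attention mask must keep enough right padding for the
    # request with the most tokens still to generate:
    new_padding_right_offset = max(remaining)                            # 15
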
@@ -190,8 +191,10 @@ class CausalLMBatch(Batch):
         position_ids = self.position_ids[keep_indices]
         self.attention_mask = self.attention_mask[
             keep_indices,
-            -(self.padding_right_offset + max_input_length):
-            (self.attention_mask.shape[1] - self.padding_right_offset) + new_padding_right_offset,
+            -(self.padding_right_offset + max_input_length) : (
+                self.attention_mask.shape[1] - self.padding_right_offset
+            )
+            + new_padding_right_offset,
         ]
 
         # Ensure that past_key_values tensors can be updated in-place
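
The slice being rewrapped here is easier to follow with concrete sizes. A minimal sketch with toy dimensions (not the model's real shapes): the mask ends with `max_input_length` used columns plus `padding_right_offset` columns reserved for future tokens, and filtering shrinks that reservation to `new_padding_right_offset`.

    import torch

    attention_mask = torch.ones(1, 12)   # toy mask: 12 columns total
    max_input_length = 6                 # columns already used by prompt and past tokens
    padding_right_offset = 4             # columns reserved for future tokens
    new_padding_right_offset = 2         # smaller reservation after filtering

    kept = attention_mask[
        :,
        -(padding_right_offset + max_input_length) : (
            attention_mask.shape[1] - padding_right_offset
        )
        + new_padding_right_offset,
    ]
    # 6 used columns + 2 reserved columns survive the slice; extra padding is dropped.
    assert kept.shape[1] == max_input_length + new_padding_right_offset
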
@@ -329,7 +332,8 @@ class CausalLMBatch(Batch):
             # And ensure that we can update tensors in-place
             if type(batch.past_key_values[0]) == tuple:
                 batch.past_key_values = [
-                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values
+                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
+                    for layer in batch.past_key_values
                 ]
             elif batch.past_key_values[0][0].shape == 3:
                 for layer in batch.past_key_values:
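
What the reflowed comprehension does, sketched with invented shapes and assuming a BLOOM-style cache whose leading dimension is batch times heads: `.view` regroups that dimension per request and aliases the original storage, and building lists instead of keeping the model's tuples is what allows the later in-place updates.

    import torch

    batch_size, num_heads, seq_len, head_dim = 2, 4, 3, 8
    # Cache tensor flattened over batch and heads, as BLOOM returns it.
    t = torch.randn(batch_size * num_heads, seq_len, head_dim)

    # Regroup the leading dimension so each request owns a contiguous block.
    per_request = t.view(batch_size, -1, *t.shape[-2:])
    assert per_request.shape == (batch_size, num_heads, seq_len, head_dim)

    # .view() is an alias, not a copy: updating one request's block is visible
    # through the original tensor, which is the in-place property the
    # surrounding comment asks for.
    per_request[0].zero_()
    assert torch.all(t[:num_heads] == 0)
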
@@ -390,7 +394,9 @@ class CausalLMBatch(Batch):
                 start_index = end_index
 
-            padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape)
+            padded_past_values = first_past_kvs[j][1].new_zeros(
+                padded_past_values_shape
+            )
             start_index = 0
             for batch in batches:
                 past_values = batch.past_key_values[j][1]
@@ -411,7 +417,6 @@ class CausalLMBatch(Batch):
             past_key_values.append([padded_past_keys, padded_past_values])
 
-
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,

View File

@@ -200,7 +200,8 @@ class Seq2SeqLMBatch(Batch):
             )
             padding_right_offset = max(
                 padding_right_offset,
-                self.stopping_criterias[idx].max_new_tokens - self.stopping_criterias[idx].current_tokens
+                self.stopping_criterias[idx].max_new_tokens
+                - self.stopping_criterias[idx].current_tokens,
             )
 
             next_token_choosers.append(self.next_token_choosers[idx])
@@ -216,15 +217,21 @@ class Seq2SeqLMBatch(Batch):
         if self.decoder_attention_mask is not None:
             self.decoder_attention_mask = self.decoder_attention_mask[
                 keep_indices,
-                -(self.padding_right_offset + max_decoder_input_length):
-                (self.decoder_attention_mask.shape[1] - self.padding_right_offset) + padding_right_offset,
+                -(self.padding_right_offset + max_decoder_input_length) : (
+                    self.decoder_attention_mask.shape[1] - self.padding_right_offset
+                )
+                + padding_right_offset,
             ]
 
-        self.encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices, -max_input_length:]
+        self.encoder_last_hidden_state = self.encoder_last_hidden_state[
+            keep_indices, -max_input_length:
+        ]
 
         # Ensure that past_key_values tensors can be updated in-place
         if type(self.past_key_values[0]) == tuple:
-            self.past_key_values = [[t for t in layer] for layer in self.past_key_values]
+            self.past_key_values = [
+                [t for t in layer] for layer in self.past_key_values
+            ]
 
         decoder_past_seq_len = max_decoder_input_length - 1
         for layer in self.past_key_values:
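
Besides the re-wrapping, this hunk trims the encoder output to the surviving rows and the last `max_input_length` positions, and converts the tuple-of-tuples cache into lists so entries can be overwritten. A toy-shaped sketch of the trimming (sizes invented, assuming the left-padding convention these batch classes use):

    import torch

    encoder_last_hidden_state = torch.randn(3, 10, 4)  # 3 requests, 10 padded positions
    keep_indices = [0, 2]                              # requests that survive the filter
    max_input_length = 7                               # longest remaining source length

    # Row selection and right-aligned trimming in one indexing expression.
    kept = encoder_last_hidden_state[keep_indices, -max_input_length:]
    assert kept.shape == (2, 7, 4)
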
@@ -255,7 +262,6 @@ class Seq2SeqLMBatch(Batch):
         return self
 
-
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
@@ -387,7 +393,9 @@ class Seq2SeqLMBatch(Batch):
             # Ensure that we can update tensors in-place
             if type(batch.past_key_values[0]) == tuple:
-                batch.past_key_values = [[t for t in layer] for layer in batch.past_key_values]
+                batch.past_key_values = [
+                    [t for t in layer] for layer in batch.past_key_values
+                ]
 
             start_index = end_index
 
         # Add eventual padding tokens that were added while concatenating
@@ -435,9 +443,9 @@ class Seq2SeqLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past keys and values to remove the padding from previous batches
                 past_seq_len = batch.max_decoder_input_length - 1
-                padded_past_values[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = t[:, :, -past_seq_len:, :]
+                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
+                    :, :, -past_seq_len:, :
+                ]
                 del t
 
                 start_index = end_index
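
The assignment reshuffled here is the right-aligned copy used when concatenating batches whose cached decoder lengths differ: the merged buffer is allocated for the longest past sequence and each batch writes into its rightmost positions. A sketch with invented shapes:

    import torch

    # Two toy batches: (batch, heads, past_seq_len, head_dim), different past lengths.
    t_a = torch.ones(2, 4, 3, 8)
    t_b = torch.full((1, 4, 5, 8), 2.0)
    max_past_seq_len = 5

    padded_past_values = t_a.new_zeros(3, 4, max_past_seq_len, 8)
    # Right-align each batch so the most recent positions line up across batches.
    padded_past_values[0:2, :, -t_a.shape[2]:, :] = t_a
    padded_past_values[2:3, :, -t_b.shape[2]:, :] = t_b

    # The two leftmost positions of the shorter batch remain zero padding.
    assert torch.all(padded_past_values[0:2, :, :2, :] == 0)
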
@@ -457,13 +465,12 @@ class Seq2SeqLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past keys and values to remove the padding from previous batches
                 padded_past_values[
-                    start_index:end_index, :, -batch.max_input_length:, :
-                ] = t[:, :, -batch.max_input_length:, :]
+                    start_index:end_index, :, -batch.max_input_length :, :
+                ] = t[:, :, -batch.max_input_length :, :]
                 del t
 
                 start_index = end_index
 
-
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,