OlivierDehaene 2023-04-24 16:08:50 +02:00
parent 885411e747
commit 889897fe69
4 changed files with 46 additions and 36 deletions

View File

@@ -181,9 +181,7 @@ def test_causal_lm_generate_token_completion_multi(
     next_batch = next_batch.filter([next_batch.requests[0]])

     for _ in range(
-        stopping_criterias[0].max_new_tokens
-        - stopping_criterias[1].max_new_tokens
-        - 1
+        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
         generations, next_batch = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)

View File

@@ -174,14 +174,14 @@ def test_causal_lm_generate_token_completion_multi(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )

     # Copy stopping_criterias before filtering
-    stopping_criterias = default_multi_requests_causal_lm_batch.stopping_criterias.copy()
+    stopping_criterias = (
+        default_multi_requests_causal_lm_batch.stopping_criterias.copy()
+    )
     next_batch = next_batch.filter([next_batch.requests[0]])

     for _ in range(
-        stopping_criterias[0].max_new_tokens
-        - stopping_criterias[1].max_new_tokens
-        - 1
+        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
         generations, next_batch = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

View File

@@ -178,11 +178,12 @@ class CausalLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
-            remaining_decode_tokens = stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            remaining_decode_tokens = (
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            )
             total_remaining_decode_tokens += remaining_decode_tokens
             new_padding_right_offset = max(
-                new_padding_right_offset,
-                remaining_decode_tokens
+                new_padding_right_offset, remaining_decode_tokens
             )

         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
@@ -190,8 +191,10 @@ class CausalLMBatch(Batch):
         position_ids = self.position_ids[keep_indices]
         self.attention_mask = self.attention_mask[
             keep_indices,
-            -(self.padding_right_offset + max_input_length):
-            (self.attention_mask.shape[1] - self.padding_right_offset) + new_padding_right_offset,
+            -(self.padding_right_offset + max_input_length) : (
+                self.attention_mask.shape[1] - self.padding_right_offset
+            )
+            + new_padding_right_offset,
         ]

         # Ensure that past_key_values tensors can be updated in-place
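For readers following the slice above: the attention mask is laid out as [left padding | prompt tokens | right padding], and filtering keeps the last max_input_length prompt columns plus the right padding still needed. A minimal illustrative sketch of that pattern (not part of this commit; PyTorch assumed, all sizes hypothetical):

import torch

padding_right_offset = 4      # hypothetical: right padding reserved before filtering
new_padding_right_offset = 2  # hypothetical: tokens still to be generated afterwards
max_input_length = 3          # hypothetical: longest remaining prompt

# Mask laid out as [left padding | prompt tokens | right padding], shape [batch, seq].
attention_mask = torch.zeros(2, 10, dtype=torch.int64)
attention_mask[:, 3:6] = 1  # the 3 prompt columns are attended

keep_indices = [1]  # keep only the second request

# Same slice pattern as above: last `max_input_length` prompt columns plus the
# right padding that is still needed after filtering.
filtered = attention_mask[
    keep_indices,
    -(padding_right_offset + max_input_length) : attention_mask.shape[1]
    - padding_right_offset
    + new_padding_right_offset,
]
print(filtered.shape)  # torch.Size([1, 5]): 3 prompt columns + 2 right-padding columns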
@@ -329,7 +332,8 @@ class CausalLMBatch(Batch):
             # And ensure that we can update tensors in-place
             if type(batch.past_key_values[0]) == tuple:
                 batch.past_key_values = [
-                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values
+                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
+                    for layer in batch.past_key_values
                 ]
             elif batch.past_key_values[0][0].shape == 3:
                 for layer in batch.past_key_values:
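The view(len(batch), -1, *t.shape[-2:]) call above works because view() returns a tensor that shares storage with the original, so the reshaped past keys/values can be written to in place. A small sketch with assumed shapes (not from this repository; len(batch) in the diff is the number of requests, written as a plain batch_size here):

import torch

batch_size, num_heads, seq_len, head_dim = 2, 4, 5, 8  # hypothetical sizes

# A past-key tensor stored with batch and heads flattened into one leading dimension.
t = torch.randn(batch_size * num_heads, seq_len, head_dim)

# Bring the batch dimension to the front, as in the diff above.
reshaped = t.view(batch_size, -1, *t.shape[-2:])
print(reshaped.shape)  # torch.Size([2, 4, 5, 8])

# view() shares storage with `t`, so writes through `reshaped` also update `t`,
# which is what "can be updated in-place" refers to above.
reshaped[0] += 1
assert torch.equal(reshaped[0], t[:num_heads])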
@@ -339,7 +343,7 @@ class CausalLMBatch(Batch):
             start_index = end_index
             # Add eventual padding tokens that were added while concatenating
             max_tokens += batch.max_tokens + (
-                    max_input_length - batch.max_input_length
+                max_input_length - batch.max_input_length
             ) * len(batch)

         first_past_kvs = batches[0].past_key_values
@@ -390,7 +394,9 @@ class CausalLMBatch(Batch):
                 start_index = end_index

-            padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape)
+            padded_past_values = first_past_kvs[j][1].new_zeros(
+                padded_past_values_shape
+            )

             start_index = 0
             for batch in batches:
                 past_values = batch.past_key_values[j][1]
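new_zeros above allocates the padded buffer with the same dtype and device as the source tensor; each batch's past values are then copied in right-aligned so shorter sequences keep their left padding. A minimal sketch of that concatenation step with hypothetical shapes (not part of this commit):

import torch

num_heads, head_dim = 4, 8  # hypothetical
max_seq_len = 6             # longest past sequence across all batches
total_batch = 3

# Allocate the padded buffer with the same dtype/device as an existing past tensor.
reference = torch.randn(1, num_heads, max_seq_len, head_dim)
padded_past_values = reference.new_zeros((total_batch, num_heads, max_seq_len, head_dim))

# Hypothetical per-batch past values with different sequence lengths.
batches = [
    torch.randn(1, num_heads, 3, head_dim),
    torch.randn(2, num_heads, max_seq_len, head_dim),
]

start_index = 0
for past_values in batches:
    end_index = start_index + past_values.shape[0]
    past_seq_len = past_values.shape[2]
    # Right-aligned copy: only the last `past_seq_len` positions hold real tokens.
    padded_past_values[start_index:end_index, :, -past_seq_len:, :] = past_values
    start_index = end_index

print(padded_past_values.shape)  # torch.Size([3, 4, 6, 8])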
@@ -411,7 +417,6 @@ class CausalLMBatch(Batch):

             past_key_values.append([padded_past_keys, padded_past_values])

-
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,

View File

@@ -200,7 +200,8 @@ class Seq2SeqLMBatch(Batch):
             )
             padding_right_offset = max(
                 padding_right_offset,
-                self.stopping_criterias[idx].max_new_tokens - self.stopping_criterias[idx].current_tokens
+                self.stopping_criterias[idx].max_new_tokens
+                - self.stopping_criterias[idx].current_tokens,
             )

             next_token_choosers.append(self.next_token_choosers[idx])
@@ -215,16 +216,22 @@ class Seq2SeqLMBatch(Batch):
         self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
         if self.decoder_attention_mask is not None:
             self.decoder_attention_mask = self.decoder_attention_mask[
-                    keep_indices,
-                    -(self.padding_right_offset + max_decoder_input_length):
-                    (self.decoder_attention_mask.shape[1] - self.padding_right_offset) + padding_right_offset,
+                keep_indices,
+                -(self.padding_right_offset + max_decoder_input_length) : (
+                    self.decoder_attention_mask.shape[1] - self.padding_right_offset
+                )
+                + padding_right_offset,
             ]

-        self.encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices, -max_input_length:]
+        self.encoder_last_hidden_state = self.encoder_last_hidden_state[
+            keep_indices, -max_input_length:
+        ]

         # Ensure that past_key_values tensors can be updated in-place
         if type(self.past_key_values[0]) == tuple:
-            self.past_key_values = [[t for t in layer] for layer in self.past_key_values]
+            self.past_key_values = [
+                [t for t in layer] for layer in self.past_key_values
+            ]

         decoder_past_seq_len = max_decoder_input_length - 1
         for layer in self.past_key_values:
@@ -234,8 +241,8 @@ class Seq2SeqLMBatch(Batch):
             layer[3] = layer[3][keep_indices, :, -max_input_length:]

         max_tokens = (
-                len(requests) * (max_input_length + max_decoder_input_length)
-                + remaining_decode_tokens
+            len(requests) * (max_input_length + max_decoder_input_length)
+            + remaining_decode_tokens
         )

         self.requests = requests
@@ -255,7 +262,6 @@ class Seq2SeqLMBatch(Batch):

         return self

-
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
@@ -387,15 +393,17 @@ class Seq2SeqLMBatch(Batch):

             # Ensure that we can update tensors in-place
             if type(batch.past_key_values[0]) == tuple:
-                batch.past_key_values = [[t for t in layer] for layer in batch.past_key_values]
+                batch.past_key_values = [
+                    [t for t in layer] for layer in batch.past_key_values
+                ]

             start_index = end_index
             # Add eventual padding tokens that were added while concatenating
             max_tokens += batch.max_tokens + (
-                    max_input_length
-                    - batch.max_input_length
-                    + max_decoder_input_length
-                    - batch.max_decoder_input_length
+                max_input_length
+                - batch.max_input_length
+                + max_decoder_input_length
+                - batch.max_decoder_input_length
             ) * len(batch)

         # Determine shapes for new past kv tensors
@@ -435,9 +443,9 @@ class Seq2SeqLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past keys and values to remove the padding from previous batches
                 past_seq_len = batch.max_decoder_input_length - 1
-                padded_past_values[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = t[:, :, -past_seq_len:, :]
+                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
+                    :, :, -past_seq_len:, :
+                ]
                 del t
                 start_index = end_index

@@ -457,13 +465,12 @@ class Seq2SeqLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past keys and values to remove the padding from previous batches
                 padded_past_values[
-                    start_index:end_index, :, -batch.max_input_length:, :
-                ] = t[:, :, -batch.max_input_length:, :]
+                    start_index:end_index, :, -batch.max_input_length :, :
+                ] = t[:, :, -batch.max_input_length :, :]
                 del t

                 start_index = end_index

-
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,