From 8e34beed321f4b1718fb7e2b08c17d083a94806c Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Mon, 24 Apr 2023 07:25:25 +0100
Subject: [PATCH] equivalent changes for seq2seq_lm

---
 server/tests/models/test_seq2seq_lm.py |  33 +++-
 .../models/seq2seq_lm.py               | 175 ++++++++++--------
 2 files changed, 125 insertions(+), 83 deletions(-)

diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py
index 79c9e936..65dafa50 100644
--- a/server/tests/models/test_seq2seq_lm.py
+++ b/server/tests/models/test_seq2seq_lm.py
@@ -219,6 +219,19 @@ def test_batch_concatenate(
     next_batch_1 = default_multi_requests_seq2seq_lm_batch
     _, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)
 
+    # Copy hidden state because it is removed from the concatenated branches
+    next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state
+    next_batch_1_encoder_last_hidden_state = next_batch_1.encoder_last_hidden_state
+
+    # Clone past_key_values before concatenating to compare after,
+    # because they are removed from the concatenated batches
+    next_batch_0_past_key_values = [
+        [t.clone() for t in layer] for layer in next_batch_0.past_key_values
+    ]
+    next_batch_1_past_key_values = [
+        [t.clone() for t in layer] for layer in next_batch_1.past_key_values
+    ]
+
     next_batch = Seq2SeqLMBatch.concatenate([next_batch_0, next_batch_1])
 
     assert next_batch.batch_id == 0
@@ -239,11 +252,11 @@
 
     assert torch.equal(
         next_batch.encoder_last_hidden_state[0],
-        next_batch_0.encoder_last_hidden_state[0, -2:],
+        next_batch_0_encoder_last_hidden_state[0, -2:],
     )
     assert torch.equal(
         next_batch.encoder_last_hidden_state[1:],
-        next_batch_1.encoder_last_hidden_state[:, -2:],
+        next_batch_1_encoder_last_hidden_state[:, -2:],
     )
 
     assert next_batch.input_lengths == [2, 2, 2]
@@ -275,24 +288,24 @@
     )
 
     for i, past in enumerate(next_batch.past_key_values):
-        assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:, :], past[0][0])
+        assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:, :], past[0][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :]
+            next_batch_1_past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :]
         )
 
-        assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:, :], past[1][0])
+        assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:, :], past[1][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :]
+            next_batch_1_past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :]
         )
 
-        assert torch.equal(next_batch_0.past_key_values[i][2][0, :, -2:, :], past[2][0])
+        assert torch.equal(next_batch_0_past_key_values[i][2][0, :, -2:, :], past[2][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][2][:, :, -2:, :], past[2][1:]
+            next_batch_1_past_key_values[i][2][:, :, -2:, :], past[2][1:]
        )
 
-        assert torch.equal(next_batch_0.past_key_values[i][3][0, :, -2:, :], past[3][0])
+        assert torch.equal(next_batch_0_past_key_values[i][3][0, :, -2:, :], past[3][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][3][:, :, -2:, :], past[3][1:]
+            next_batch_1_past_key_values[i][3][:, :, -2:, :], past[3][1:]
         )
 
     for _ in range(3):
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index aa452c70..2252fcfc 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -25,7 +25,7 @@ class Seq2SeqLMBatch(Batch):
     requests_idx_mapping: Dict[int, int]
 
     # Encoder values
-    input_ids: torch.Tensor
+    input_ids: Optional[torch.Tensor]
     attention_mask: torch.Tensor
 
     # Decoder values
@@ -164,6 +164,7 @@ class Seq2SeqLMBatch(Batch):
 
         max_input_length = 0
         max_decoder_input_length = 0
+        padding_right_offset = 0
 
         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
@@ -184,45 +185,53 @@ class Seq2SeqLMBatch(Batch):
             max_decoder_input_length = max(
                 max_decoder_input_length, request_decoder_input_length
             )
+            padding_right_offset = max(
+                padding_right_offset,
+                self.stopping_criterias[idx].max_new_tokens - self.stopping_criterias[idx].current_tokens
+            )
 
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criterias.append(self.stopping_criterias[idx])
 
         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
-        decoder_input_ids = self.decoder_input_ids[keep_indices]
-        attention_mask = self.attention_mask[keep_indices]
+        self.decoder_input_ids = self.decoder_input_ids[keep_indices]
+        self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
         if self.decoder_attention_mask is not None:
-            decoder_attention_mask = self.decoder_attention_mask[keep_indices]
-        else:
-            decoder_attention_mask = None
+            self.decoder_attention_mask = self.decoder_attention_mask[
+                keep_indices,
+                -(self.padding_right_offset + max_decoder_input_length):
+                (self.decoder_attention_mask.shape[1] - self.padding_right_offset) + padding_right_offset,
+            ]
 
-        encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices]
+        self.encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices, -max_input_length:]
 
-        past_key_values = [
-            [t[keep_indices] for t in layer] for layer in self.past_key_values
-        ]
+        # Ensure that past_key_values tensors can be updated in-place
+        if type(self.past_key_values[0]) == tuple:
+            self.past_key_values = [[t for t in layer] for layer in self.past_key_values]
+
+        decoder_past_seq_len = max_decoder_input_length - 1
+        for layer in self.past_key_values:
+            layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]
+            layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]
+            layer[2] = layer[2][keep_indices, :, -max_input_length:]
+            layer[3] = layer[3][keep_indices, :, -max_input_length:]
+
+        self.requests = requests
+        self.requests_idx_mapping = requests_idx_mapping
+        self.input_ids = None
+        self.all_decoder_input_ids = all_decoder_input_ids
+        self.input_lengths = input_lengths
+        self.decoder_input_lengths = decoder_input_lengths
+        self.offsets = offsets
+        self.token_offsets = token_offsets
+        self.next_token_choosers = next_token_choosers
+        self.stopping_criterias = stopping_criterias
+        self.max_input_length = max_input_length
+        self.max_decoder_input_length = max_decoder_input_length
+        self.padding_right_offset = padding_right_offset
+
+        return self
 
-        return Seq2SeqLMBatch(
-            batch_id=self.batch_id,
-            requests=requests,
-            requests_idx_mapping=requests_idx_mapping,
-            input_ids=None,
-            attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            all_decoder_input_ids=all_decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            encoder_last_hidden_state=encoder_last_hidden_state,
-            past_key_values=past_key_values,
-            input_lengths=input_lengths,
-            decoder_input_lengths=decoder_input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
-            next_token_choosers=next_token_choosers,
-            stopping_criterias=stopping_criterias,
-            max_input_length=max_input_length,
-            max_decoder_input_length=max_decoder_input_length,
-            padding_right_offset=self.padding_right_offset,
-        )
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -350,58 +359,78 @@ class Seq2SeqLMBatch(Batch):
             encoder_last_hidden_state[
                 start_index:end_index, -batch.max_input_length :, :
             ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
+            batch.encoder_last_hidden_state = None
 
-            # Iterate over attention layers
-            for j, past in enumerate(batch.past_key_values):
-                _, num_heads, _, head_dim = past[0].shape
+            # Ensure that we can update tensors in-place
+            if type(batch.past_key_values[0]) == tuple:
+                batch.past_key_values = [[t for t in layer] for layer in batch.past_key_values]
 
-                # This will run only once per layer
-                if j == len(past_key_values):
-                    past_key_values.append([])
+            start_index = end_index
 
-                # Decoder past
-                for k, t in enumerate(past[:2]):
-                    padded_t_shape = (
-                        total_batch_size,
-                        num_heads,
-                        (max_decoder_input_length - 1),
-                        head_dim,
-                    )
+        # Determine shapes for new past kv tensors
+        first_past_kvs = batches[0].past_key_values
+        _, num_heads, _, head_dim = first_past_kvs[0][0].shape
 
-                    # Initialize tensors
-                    # This will run only once per layer and per past tensor
-                    if k == len(past_key_values[j]):
-                        past_key_values[j].append(t.new_zeros(padded_t_shape))
+        padded_dec_t_shape = (
+            total_batch_size,
+            num_heads,
+            (max_decoder_input_length - 1),
+            head_dim,
+        )
+        padded_enc_t_shape = (
+            total_batch_size,
+            num_heads,
+            max_input_length,
+            head_dim,
+        )
+
+        # Iterate over attention layers
+        for j in range(len(first_past_kvs)):
+            past_key_values.append([])
+
+            # Decoder past
+            for k in range(0, 2):
+                # Initialize tensors
+                padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape)
+                past_key_values[j].append(padded_past_values)
+
+                start_index = 0
+                for batch in batches:
+                    t = batch.past_key_values[j][k]
+                    # Clear reference to the original tensor
+                    batch.past_key_values[j][k] = None
+                    # Slicing end index for this batch
+                    end_index = start_index + len(batch)
                     # We slice the past keys and values to remove the padding from previous batches
-                    past_key_values[j][k][
-                        start_index:end_index,
-                        :,
-                        -(batch.max_decoder_input_length - 1) :,
-                        :,
-                    ] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]
+                    past_seq_len = batch.max_decoder_input_length - 1
+                    padded_past_values[
+                        start_index:end_index, :, -past_seq_len:, :
+                    ] = t[:, :, -past_seq_len:, :]
+                    del t
 
-                # encoder past
-                for k, t in enumerate(past[2:]):
-                    padded_t_shape = (
-                        total_batch_size,
-                        num_heads,
-                        max_input_length,
-                        head_dim,
-                    )
+                    start_index = end_index
 
-                    idx = k + 2
+            # Encoder past
+            for k in range(2, 4):
+                # Initialize tensors
+                padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape)
+                past_key_values[j].append(padded_past_values)
 
-                    # Initialize tensors
-                    # This will run only once per layer and per past tensor
-                    if idx == len(past_key_values[j]):
-                        past_key_values[j].append(t.new_zeros(padded_t_shape))
+                start_index = 0
+                for batch in batches:
+                    t = batch.past_key_values[j][k]
+                    # Clear reference to the original tensor
+                    batch.past_key_values[j][k] = None
+                    # Slicing end index for this batch
+                    end_index = start_index + len(batch)
+                    # We slice the past keys and values to remove the padding from previous batches
+                    padded_past_values[
+                        start_index:end_index, :, -batch.max_input_length:, :
+                    ] = t[:, :, -batch.max_input_length:, :]
+                    del t
 
-                    past_key_values[j][idx][
-                        start_index:end_index, :, -batch.max_input_length :, :
-                    ] = t[:, :, -batch.max_input_length :, :]
-
-            start_index += len(batch)
 
         return cls(
             batch_id=batches[0].batch_id,