From 12326eff626b7e344234610243b4e0912c6e2c5f Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 20 Apr 2023 14:54:01 -0700
Subject: [PATCH] feat(server): reduce memory requirement

---
 .../models/causal_lm.py | 189 +++++++++++-------
 1 file changed, 119 insertions(+), 70 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 98313253..cc89049e 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -150,6 +150,8 @@ class CausalLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []
 
+        new_padding_right_offset = 0
+
         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
             requests_idx_mapping[r.id] = i
@@ -164,36 +166,57 @@ class CausalLMBatch(Batch):
             max_input_length = max(max_input_length, request_input_length)
 
             next_token_choosers.append(self.next_token_choosers[idx])
-            stopping_criterias.append(self.stopping_criterias[idx])
+            stopping_criteria = self.stopping_criterias[idx]
+            stopping_criterias.append(stopping_criteria)
+
+            new_padding_right_offset = max(
+                new_padding_right_offset,
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            )
 
         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
         input_ids = self.input_ids[keep_indices]
-        attention_mask = self.attention_mask[keep_indices]
         position_ids = self.position_ids[keep_indices]
-        # Force past to be of dim [self_size, num_heads, ...] for easy indexing
-        past_key_values = [
-            [t.view(len(self), -1, *t.shape[-2:])[keep_indices] for t in layer]
-            for layer in self.past_key_values
+        self.attention_mask = self.attention_mask[
+            keep_indices,
+            -(self.padding_right_offset + self.max_input_length):
+            (self.attention_mask.shape[1] - self.padding_right_offset) + new_padding_right_offset,
         ]
-        return CausalLMBatch(
-            batch_id=self.batch_id,
-            requests=requests,
-            requests_idx_mapping=requests_idx_mapping,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            all_input_ids=all_input_ids,
-            input_lengths=input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
-            next_token_choosers=next_token_choosers,
-            stopping_criterias=stopping_criterias,
-            max_input_length=max_input_length,
-            padding_right_offset=self.padding_right_offset,
-            keys_head_dim_last=self.keys_head_dim_last,
-        )
+
+        # Ensure that past_key_values tensors can be updated in-place
+        if type(self.past_key_values[0]) == tuple:
+            self.past_key_values = [list(layer) for layer in self.past_key_values]
+
+        # Update tensors in-place to allow incremental garbage collection
+        past_kv_length = self.max_input_length - 1
+        for layer in self.past_key_values:
+            past_keys, past_values = layer
+            if len(past_keys.shape) == 3:
+                # Force past to be of dim [self_size, num_heads, ...] for easy indexing
+                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
+                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
+            if self.keys_head_dim_last:
+                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
+            else:
+                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
+            del past_keys
+            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
+            del past_values
+
+        self.requests = requests
+        self.requests_idx_mapping = requests_idx_mapping
+        self.input_ids = input_ids
+        self.position_ids = position_ids
+        self.all_input_ids = all_input_ids
+        self.input_lengths = input_lengths
+        self.offsets = offsets
+        self.token_offsets = token_offsets
+        self.next_token_choosers = next_token_choosers
+        self.stopping_criterias = stopping_criterias
+        self.max_input_length = max_input_length
+        self.padding_right_offset = new_padding_right_offset
+
+        return self
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -285,62 +308,88 @@ class CausalLMBatch(Batch):
             position_ids = batch.position_ids.new_empty((total_batch_size, 1))
             position_ids[start_index:end_index] = batch.position_ids
 
-            for j, past in enumerate(batch.past_key_values):
-                past_keys, past_values = past
+            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
+            # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
+            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
+            # And ensure that we can update tensors in-place
+            if type(batch.past_key_values[0]) == tuple:
+                batch.past_key_values = [
+                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values
+                ]
+            elif len(batch.past_key_values[0][0].shape) == 3:
+                for layer in batch.past_key_values:
+                    for k, t in enumerate(layer):
+                        layer[k] = t.view(len(batch), -1, *t.shape[-2:])
 
-                # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
-                # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
-                # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
-                past_keys = past_keys.view(len(batch), -1, *past_keys.shape[-2:])
-                past_values = past_values.view(len(batch), -1, *past_values.shape[-2:])
+            start_index = end_index
 
-                _, num_heads, padded_sequence_length, head_dim = past_values.shape
+        first_past_kvs = batches[0].past_key_values
+        _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
 
-                padded_past_values_shape = (
-                    total_batch_size,
-                    num_heads,
-                    max_input_length - 1,
-                    head_dim,
-                )
+        padded_past_values_shape = (
+            total_batch_size,
+            num_heads,
+            max_input_length - 1,
+            head_dim,
+        )
+        if batches[0].keys_head_dim_last:
+            padded_past_keys_shape = padded_past_values_shape
+        else:
+            # seq_length is last for BLOOM
+            padded_past_keys_shape = (
+                total_batch_size,
+                num_heads,
+                head_dim,
+                max_input_length - 1,
+            )
+
+        # Iterate over attention layers
+        # Concatenate past key values layer by layer to allow incremental garbage collection
+        for j in range(len(first_past_kvs)):
+            padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
+            start_index = 0
+            for batch in batches:
+                past_keys = batch.past_key_values[j][0]
+                # Clear reference to the original tensor
+                batch.past_key_values[j][0] = None
+
+                # Slicing end index for this batch
+                end_index = start_index + len(batch)
+                # We slice the keys to remove the padding from previous batches
+                past_seq_len = batch.max_input_length - 1
                 if batch.keys_head_dim_last:
-                    padded_past_keys_shape = padded_past_values_shape
+                    padded_past_keys[
+                        start_index:end_index, :, -past_seq_len:, :
+                    ] = past_keys[:, :, -past_seq_len:, :]
                 else:
-                    # seq_length is last for BLOOM
-                    padded_past_keys_shape = (
-                        total_batch_size,
-                        num_heads,
-                        head_dim,
-                        max_input_length - 1,
-                    )
+                    # BLOOM case
+                    padded_past_keys[
+                        start_index:end_index, :, :, -past_seq_len:
+                    ] = past_keys[:, :, :, -past_seq_len:]
+                del past_keys
 
-                # This will run only once per layer
-                if j == len(past_key_values):
-                    padded_past_keys = past_keys.new_zeros(padded_past_keys_shape)
-                    padded_past_values = past_values.new_zeros(padded_past_values_shape)
-                    past_key_values.append((padded_past_keys, padded_past_values))
+                start_index = end_index
 
-                # We slice the past keys and values to remove the padding from previous batches
-                if batch.keys_head_dim_last:
-                    past_key_values[j][0][
-                        start_index:end_index,
-                        :,
-                        -(batch.max_input_length - 1) :,
-                        :,
-                    ] = past_keys[:, :, -(batch.max_input_length - 1) :, :]
-                else:
-                    past_key_values[j][0][
-                        start_index:end_index,
-                        :,
-                        :,
-                        -(batch.max_input_length - 1) :,
-                    ] = past_keys[:, :, :, -(batch.max_input_length - 1) :]
+            padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape)
+            start_index = 0
+            for batch in batches:
+                past_values = batch.past_key_values[j][1]
+                # Clear reference to the original tensor
+                batch.past_key_values[j][1] = None
 
-                past_key_values[j][1][
-                    start_index:end_index, :, -(batch.max_input_length - 1) :, :
-                ] = past_values[:, :, -(batch.max_input_length - 1) :, :]
+                # Slicing end index for this batch
+                end_index = start_index + len(batch)
+                # We slice the past values to remove the padding from previous batches
+                past_seq_len = batch.max_input_length - 1
+                padded_past_values[
+                    start_index:end_index, :, -past_seq_len:, :
+                ] = past_values[:, :, -past_seq_len:, :]
+                del past_values
 
-                start_index += len(batch)
+                start_index = end_index
+
+            past_key_values.append([padded_past_keys, padded_past_values])
 
         return cls(
             batch_id=batches[0].batch_id,
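
Note (editorial, not part of the patch): the memory win in filter() comes from overwriting each layer's key/value slot with the sliced copy and dropping the reference to the old, larger tensor before moving on, so the allocator can reclaim the cache layer by layer instead of holding two full copies alive at once. A minimal standalone sketch of that pattern, assuming plain [batch, heads, seq, head_dim] tensors with made-up sizes (this is not the TGI code itself):

    import torch

    # Toy cache: 2 layers of [batch, heads, seq, head_dim] key/value tensors.
    past_key_values = [
        [torch.randn(4, 8, 10, 64), torch.randn(4, 8, 10, 64)] for _ in range(2)
    ]
    keep_indices = torch.tensor([0, 2])  # requests kept after filtering
    past_kv_length = 7                   # shorter past length to keep

    for layer in past_key_values:
        past_keys, past_values = layer
        # Overwrite the list slots with the sliced copies...
        layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
        layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
        # ...then drop the local names so the old tensors become
        # collectable before the next layer is processed.
        del past_keys, past_values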
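
Similarly, concatenate() now allocates the padded destination tensors one attention layer at a time and nulls out each source batch's slot right after copying, keeping the peak overhead to roughly one extra layer rather than a duplicate of every batch's full cache. A rough, self-contained sketch of that layer-by-layer pattern, with a hypothetical helper name and a [batch, heads, seq, head_dim] layout assumed (the keys_head_dim_last / BLOOM handling is omitted):

    import torch

    def concat_kv_layerwise(batches_kv, max_len):
        # batches_kv: one entry per batch; each entry is a list over layers
        # of [keys, values] tensors shaped [batch, heads, seq, head_dim].
        num_layers = len(batches_kv[0])
        heads = batches_kv[0][0][0].shape[1]
        head_dim = batches_kv[0][0][0].shape[3]
        total = sum(kv[0][0].shape[0] for kv in batches_kv)
        merged = []
        for j in range(num_layers):
            keys_out = batches_kv[0][j][0].new_zeros((total, heads, max_len, head_dim))
            values_out = batches_kv[0][j][1].new_zeros((total, heads, max_len, head_dim))
            start = 0
            for kv in batches_kv:
                keys, values = kv[j]
                kv[j][0] = kv[j][1] = None  # clear source refs so they can be freed
                end = start + keys.shape[0]
                seq = keys.shape[2]
                # Right-align each batch so left padding stays on the left.
                keys_out[start:end, :, -seq:, :] = keys
                values_out[start:end, :, -seq:, :] = values
                del keys, values
                start = end
            merged.append([keys_out, values_out])
        return merged

    # Example: merge a 2-request batch (past length 5) with a 3-request batch (past length 7).
    kv_a = [[torch.randn(2, 8, 5, 64), torch.randn(2, 8, 5, 64)]]
    kv_b = [[torch.randn(3, 8, 7, 64), torch.randn(3, 8, 7, 64)]]
    merged = concat_kv_layerwise([kv_a, kv_b], max_len=7)  # merged[0][0].shape == (5, 8, 7, 64)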