From 889897fe69499487f81638cb573e58487d2674cd Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Mon, 24 Apr 2023 16:08:50 +0200
Subject: [PATCH] black

---
 server/tests/models/test_bloom.py     |  4 +-
 server/tests/models/test_causal_lm.py |  8 ++--
 .../models/causal_lm.py               | 23 +++++----
 .../models/seq2seq_lm.py              | 47 +++++++++++--------
 4 files changed, 46 insertions(+), 36 deletions(-)

diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py
index 47d701eb..f0adab97 100644
--- a/server/tests/models/test_bloom.py
+++ b/server/tests/models/test_bloom.py
@@ -181,9 +181,7 @@ def test_causal_lm_generate_token_completion_multi(
     next_batch = next_batch.filter([next_batch.requests[0]])
 
     for _ in range(
-        stopping_criterias[0].max_new_tokens
-        - stopping_criterias[1].max_new_tokens
-        - 1
+        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
         generations, next_batch = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)
diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
index 03d3ef9b..f1f13e4b 100644
--- a/server/tests/models/test_causal_lm.py
+++ b/server/tests/models/test_causal_lm.py
@@ -174,14 +174,14 @@ def test_causal_lm_generate_token_completion_multi(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
     # Copy stopping_criterias before filtering
-    stopping_criterias = default_multi_requests_causal_lm_batch.stopping_criterias.copy()
+    stopping_criterias = (
+        default_multi_requests_causal_lm_batch.stopping_criterias.copy()
+    )
 
     next_batch = next_batch.filter([next_batch.requests[0]])
 
     for _ in range(
-        stopping_criterias[0].max_new_tokens
-        - stopping_criterias[1].max_new_tokens
-        - 1
+        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
         generations, next_batch = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index c9650b39..336c9823 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -178,11 +178,12 @@ class CausalLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
-            remaining_decode_tokens = stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            remaining_decode_tokens = (
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            )
             total_remaining_decode_tokens += remaining_decode_tokens
             new_padding_right_offset = max(
-                new_padding_right_offset,
-                remaining_decode_tokens
+                new_padding_right_offset, remaining_decode_tokens
             )
 
         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
@@ -190,8 +191,10 @@ class CausalLMBatch(Batch):
         position_ids = self.position_ids[keep_indices]
         self.attention_mask = self.attention_mask[
             keep_indices,
-            -(self.padding_right_offset + max_input_length):
-            (self.attention_mask.shape[1] - self.padding_right_offset) + new_padding_right_offset,
+            -(self.padding_right_offset + max_input_length) : (
+                self.attention_mask.shape[1] - self.padding_right_offset
+            )
+            + new_padding_right_offset,
         ]
 
         # Ensure that past_key_values tensors can be updated in-place
@@ -329,7 +332,8 @@ class CausalLMBatch(Batch):
 
             # And ensure that we can update tensors in-place
             if type(batch.past_key_values[0]) == tuple:
                 batch.past_key_values = [
-                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values
+                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
+                    for layer in batch.past_key_values
                 ]
             elif batch.past_key_values[0][0].shape == 3:
                 for layer in batch.past_key_values:
@@ -339,7 +343,7 @@ class CausalLMBatch(Batch):
             start_index = end_index
             # Add eventual padding tokens that were added while concatenating
             max_tokens += batch.max_tokens + (
-                    max_input_length - batch.max_input_length
+                max_input_length - batch.max_input_length
             ) * len(batch)
 
         first_past_kvs = batches[0].past_key_values
@@ -390,7 +394,9 @@ class CausalLMBatch(Batch):
                 start_index = end_index
 
-            padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape)
+            padded_past_values = first_past_kvs[j][1].new_zeros(
+                padded_past_values_shape
+            )
 
             start_index = 0
             for batch in batches:
                 past_values = batch.past_key_values[j][1]
@@ -411,7 +417,6 @@ class CausalLMBatch(Batch):
 
             past_key_values.append([padded_past_keys, padded_past_values])
 
-
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index e2becb6f..0cb20760 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -200,7 +200,8 @@ class Seq2SeqLMBatch(Batch):
             )
             padding_right_offset = max(
                 padding_right_offset,
-                self.stopping_criterias[idx].max_new_tokens - self.stopping_criterias[idx].current_tokens
+                self.stopping_criterias[idx].max_new_tokens
+                - self.stopping_criterias[idx].current_tokens,
             )
 
             next_token_choosers.append(self.next_token_choosers[idx])
@@ -215,16 +216,22 @@ class Seq2SeqLMBatch(Batch):
         self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
         if self.decoder_attention_mask is not None:
             self.decoder_attention_mask = self.decoder_attention_mask[
-            keep_indices,
-            -(self.padding_right_offset + max_decoder_input_length):
-            (self.decoder_attention_mask.shape[1] - self.padding_right_offset) + padding_right_offset,
+                keep_indices,
+                -(self.padding_right_offset + max_decoder_input_length) : (
+                    self.decoder_attention_mask.shape[1] - self.padding_right_offset
+                )
+                + padding_right_offset,
             ]
 
-        self.encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices, -max_input_length:]
+        self.encoder_last_hidden_state = self.encoder_last_hidden_state[
+            keep_indices, -max_input_length:
+        ]
 
         # Ensure that past_key_values tensors can be updated in-place
         if type(self.past_key_values[0]) == tuple:
-            self.past_key_values = [[t for t in layer] for layer in self.past_key_values]
+            self.past_key_values = [
+                [t for t in layer] for layer in self.past_key_values
+            ]
 
         decoder_past_seq_len = max_decoder_input_length - 1
         for layer in self.past_key_values:
@@ -234,8 +241,8 @@ class Seq2SeqLMBatch(Batch):
             layer[3] = layer[3][keep_indices, :, -max_input_length:]
 
         max_tokens = (
-                len(requests) * (max_input_length + max_decoder_input_length)
-                + remaining_decode_tokens
+            len(requests) * (max_input_length + max_decoder_input_length)
+            + remaining_decode_tokens
         )
 
         self.requests = requests
@@ -255,7 +262,6 @@ class Seq2SeqLMBatch(Batch):
 
         return self
 
-
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
@@ -387,15 +393,17 @@ class Seq2SeqLMBatch(Batch):
             # Ensure that we can update tensors in-place
             if type(batch.past_key_values[0]) == tuple:
-                batch.past_key_values = [[t for t in layer] for layer in batch.past_key_values]
+                batch.past_key_values = [
+                    [t for t in layer] for layer in batch.past_key_values
+                ]
 
             start_index = end_index
 
             # Add eventual padding tokens that were added while concatenating
             max_tokens += batch.max_tokens + (
-                    max_input_length
-                    - batch.max_input_length
-                    + max_decoder_input_length
-                    - batch.max_decoder_input_length
+                max_input_length
+                - batch.max_input_length
+                + max_decoder_input_length
+                - batch.max_decoder_input_length
             ) * len(batch)
 
         # Determine shapes for new past kv tensors
@@ -435,9 +443,9 @@ class Seq2SeqLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past keys and values to remove the padding from previous batches
                 past_seq_len = batch.max_decoder_input_length - 1
-                padded_past_values[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = t[:, :, -past_seq_len:, :]
+                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
+                    :, :, -past_seq_len:, :
+                ]
                 del t
 
                 start_index = end_index
@@ -457,13 +465,12 @@ class Seq2SeqLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past keys and values to remove the padding from previous batches
                 padded_past_values[
                     start_index:end_index, :, -batch.max_input_length :, :
                 ] = t[:, :, -batch.max_input_length :, :]
                 del t
 
                 start_index = end_index
 
-
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
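A quick way to see the kind of transformation this patch applies is to run black itself on one of the over-long lines it touches. The sketch below is illustrative only: it assumes the `black` package is installed and reuses the pre-patch `remaining_decode_tokens` line from the `causal_lm.py` hunk above, dedented so it stands alone.

import black

# Pre-patch line from CausalLMBatch.filter (93 characters, over the default 88 limit).
src = (
    "remaining_decode_tokens = stopping_criteria.max_new_tokens"
    " - stopping_criteria.current_tokens\n"
)

# format_str applies the same rules as running `black` on the file:
# the right-hand side is wrapped in parentheses and split across lines.
print(black.format_str(src, mode=black.Mode()))

# Expected output, matching the shape of the `+` lines in the hunk:
# remaining_decode_tokens = (
#     stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
# )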