diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index ca8fccfa..7dc7fb85 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -340,12 +340,13 @@ class CausalLMBatch(Batch):
                 for k, t in enumerate(layer):
                     layer[k] = t.view(len(batch), -1, *t.shape[-2:])
 
-            start_index = end_index
             # Add eventual padding tokens that were added while concatenating
             max_tokens += batch.max_tokens + (
                 max_input_length - batch.max_input_length
             ) * len(batch)
 
+            start_index = end_index
+
         first_past_kvs = batches[0].past_key_values
         _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
 
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 0cb20760..4ac5ed3c 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -177,7 +177,7 @@ class Seq2SeqLMBatch(Batch):
         max_decoder_input_length = 0
         padding_right_offset = 0
 
-        remaining_decode_tokens = 0
+        total_remaining_decode_tokens = 0
 
         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
@@ -198,18 +198,15 @@ class Seq2SeqLMBatch(Batch):
             max_decoder_input_length = max(
                 max_decoder_input_length, request_decoder_input_length
             )
-            padding_right_offset = max(
-                padding_right_offset,
-                self.stopping_criterias[idx].max_new_tokens
-                - self.stopping_criterias[idx].current_tokens,
-            )
 
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
-            remaining_decode_tokens += (
+            remaining_decode_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
+            total_remaining_decode_tokens += remaining_decode_tokens
+            padding_right_offset = max(padding_right_offset, remaining_decode_tokens)
 
         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
         self.decoder_input_ids = self.decoder_input_ids[keep_indices]
@@ -397,7 +394,6 @@ class Seq2SeqLMBatch(Batch):
                     [t for t in layer] for layer in batch.past_key_values
                 ]
 
-            start_index = end_index
             # Add eventual padding tokens that were added while concatenating
             max_tokens += batch.max_tokens + (
                 max_input_length
@@ -406,6 +402,8 @@ class Seq2SeqLMBatch(Batch):
                 - batch.max_input_length
                 + max_decoder_input_length
                 - batch.max_decoder_input_length
             ) * len(batch)
 
+            start_index = end_index
+
             # Determine shapes for new past kv tensors
             first_past_kvs = batches[0].past_key_values
             _, num_heads, _, head_dim = first_past_kvs[0][0].shape
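
Note on the seq2seq_lm.py filter() hunks: the previous code computed the per-request remaining decode budget twice, once (via self.stopping_criterias[idx]) to grow padding_right_offset and once to accumulate remaining_decode_tokens, and the accumulator's name no longer said it was a batch-wide total. After the change, the per-request value is computed once and reused: the running sum (now total_remaining_decode_tokens) feeds the batch token budget, while the running max sets the right-padding width. A minimal sketch of that accounting, using a hypothetical SimpleNamespace stand-in for the repository's StoppingCriteria objects rather than the repository's code:

from types import SimpleNamespace

# Hypothetical stand-ins for StoppingCriteria: only the two fields the
# accounting reads are modeled here.
criterias = [
    SimpleNamespace(max_new_tokens=20, current_tokens=5),  # 15 tokens left
    SimpleNamespace(max_new_tokens=10, current_tokens=2),  # 8 tokens left
]

total_remaining_decode_tokens = 0
padding_right_offset = 0

for stopping_criteria in criterias:
    # Per-request budget still to be generated.
    remaining_decode_tokens = (
        stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
    )
    # The sum sizes the batch token budget; the max sizes the right
    # padding, since the batch must hold the longest remainder.
    total_remaining_decode_tokens += remaining_decode_tokens
    padding_right_offset = max(padding_right_offset, remaining_decode_tokens)

assert total_remaining_decode_tokens == 23  # 15 + 8
assert padding_right_offset == 15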
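
Note on the start_index / max_tokens hunks in both concatenate() methods: when batches are merged, every sequence is re-padded to the merged maxima, so each source batch contributes its own max_tokens plus the padding its requests gain; the move of start_index = end_index places the slice-index update last in each loop iteration, after that bookkeeping. A toy worked example of the seq2seq padding arithmetic, with made-up batch shapes, not the repository's code:

# Hypothetical batches: (num_requests, max_input_length, max_decoder_input_length, max_tokens).
batches = [
    (2, 10, 4, 2 * (10 + 4)),  # already at the merged maxima
    (3, 7, 3, 3 * (7 + 3)),    # shorter: its rows gain padding when merged
]

# Merged maxima across all batches.
max_input_length = max(b[1] for b in batches)          # 10
max_decoder_input_length = max(b[2] for b in batches)  # 4

max_tokens = 0
for num_requests, input_len, decoder_len, batch_max_tokens in batches:
    # Each request in a shorter batch gains this much padding on the
    # encoder and decoder sides combined.
    max_tokens += batch_max_tokens + (
        max_input_length - input_len + max_decoder_input_length - decoder_len
    ) * num_requests

# The merged budget equals all 5 requests padded to the merged maxima.
assert max_tokens == 5 * (10 + 4)  # 70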