fix(server): Small tidy of code from recent changes (#251)

remaining_decode_tokens was calculated twice in Seq2SeqLMBatch.filter()
Nick Hill, 2023-04-27 00:57:28 -07:00, committed by GitHub
parent b4cf832c40
commit 34bca0b8d3
2 changed files with 8 additions and 9 deletions


@@ -340,12 +340,13 @@ class CausalLMBatch(Batch):
             for k, t in enumerate(layer):
                 layer[k] = t.view(len(batch), -1, *t.shape[-2:])

-            start_index = end_index
             # Add eventual padding tokens that were added while concatenating
             max_tokens += batch.max_tokens + (
                 max_input_length - batch.max_input_length
             ) * len(batch)

+            start_index = end_index
+
         first_past_kvs = batches[0].past_key_values
         _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
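The start_index move here (and in the matching Seq2SeqLMBatch hunks below) is a pure reordering: the slice cursor now advances only after all per-batch work, including the max_tokens padding adjustment, is done. A hypothetical, self-contained sketch of that loop shape; MiniBatch and concatenate_token_budget are illustrative stand-ins, not the real classes:

from dataclasses import dataclass
from typing import List


@dataclass
class MiniBatch:
    # Illustrative stand-in for a batch; only the fields the loop below uses
    size: int
    max_tokens: int
    max_input_length: int

    def __len__(self) -> int:
        return self.size


def concatenate_token_budget(batches: List[MiniBatch]) -> int:
    # Mirrors the loop shape after this commit: per-batch work first,
    # slice bookkeeping (start_index = end_index) last.
    max_input_length = max(b.max_input_length for b in batches)
    start_index = 0
    max_tokens = 0
    for batch in batches:
        end_index = start_index + len(batch)
        # ... tensor copies into [start_index:end_index] happen here ...

        # Add eventual padding tokens that were added while concatenating
        max_tokens += batch.max_tokens + (
            max_input_length - batch.max_input_length
        ) * len(batch)

        start_index = end_index
    return max_tokens


# Two batches of 2 and 3 sequences: 40 + (10-10)*2 + 45 + (10-8)*3 == 91
assert concatenate_token_budget([MiniBatch(2, 40, 10), MiniBatch(3, 45, 8)]) == 91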


@@ -177,7 +177,7 @@ class Seq2SeqLMBatch(Batch):
         max_decoder_input_length = 0
         padding_right_offset = 0

-        remaining_decode_tokens = 0
+        total_remaining_decode_tokens = 0

         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
@@ -198,18 +198,15 @@ class Seq2SeqLMBatch(Batch):
             max_decoder_input_length = max(
                 max_decoder_input_length, request_decoder_input_length
             )
-            padding_right_offset = max(
-                padding_right_offset,
-                self.stopping_criterias[idx].max_new_tokens
-                - self.stopping_criterias[idx].current_tokens,
-            )

             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
-            remaining_decode_tokens += (
+            remaining_decode_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
+            total_remaining_decode_tokens += remaining_decode_tokens
+            padding_right_offset = max(padding_right_offset, remaining_decode_tokens)

         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
         self.decoder_input_ids = self.decoder_input_ids[keep_indices]
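The net effect of this hunk: each request's remaining decode budget is derived once and reused for both the right-padding offset and the running total, instead of being computed twice. A minimal runnable sketch of that pattern; filter_bookkeeping and the StoppingCriteria dataclass here are simplified stand-ins, with only max_new_tokens and current_tokens mirroring the real fields from the diff:

from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class StoppingCriteria:
    # Only the two fields the diff actually reads
    max_new_tokens: int
    current_tokens: int


def filter_bookkeeping(criterias: List[StoppingCriteria]) -> Tuple[int, int]:
    # Derive each request's remaining decode tokens once, then reuse it
    # for both the padding offset and the running total.
    padding_right_offset = 0
    total_remaining_decode_tokens = 0
    for stopping_criteria in criterias:
        remaining_decode_tokens = (
            stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
        )
        total_remaining_decode_tokens += remaining_decode_tokens
        padding_right_offset = max(padding_right_offset, remaining_decode_tokens)
    return padding_right_offset, total_remaining_decode_tokens


# Requests with 5 and 2 tokens left to decode -> offset 5, total 7,
# the same values the two separate computations produced before.
assert filter_bookkeeping(
    [StoppingCriteria(10, 5), StoppingCriteria(8, 6)]
) == (5, 7)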
@@ -397,7 +394,6 @@ class Seq2SeqLMBatch(Batch):
                 [t for t in layer] for layer in batch.past_key_values
             ]

-            start_index = end_index
             # Add eventual padding tokens that were added while concatenating
             max_tokens += batch.max_tokens + (
                 max_input_length
@@ -406,6 +402,8 @@ class Seq2SeqLMBatch(Batch):
                 - batch.max_decoder_input_length
             ) * len(batch)

+            start_index = end_index
+
         # Determine shapes for new past kv tensors
         first_past_kvs = batches[0].past_key_values
         _, num_heads, _, head_dim = first_past_kvs[0][0].shape