mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
fix(server): Small tidy of code from recent changes (#251)
remaining_decode_tokens was calculated twice in Seq2SeqLMBatch.filter()
This commit is contained in:
parent
b4cf832c40
commit
34bca0b8d3
@ -340,12 +340,13 @@ class CausalLMBatch(Batch):
|
|||||||
for k, t in enumerate(layer):
|
for k, t in enumerate(layer):
|
||||||
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
|
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
|
||||||
|
|
||||||
start_index = end_index
|
|
||||||
# Add eventual padding tokens that were added while concatenating
|
# Add eventual padding tokens that were added while concatenating
|
||||||
max_tokens += batch.max_tokens + (
|
max_tokens += batch.max_tokens + (
|
||||||
max_input_length - batch.max_input_length
|
max_input_length - batch.max_input_length
|
||||||
) * len(batch)
|
) * len(batch)
|
||||||
|
|
||||||
|
start_index = end_index
|
||||||
|
|
||||||
first_past_kvs = batches[0].past_key_values
|
first_past_kvs = batches[0].past_key_values
|
||||||
_, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
|
_, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
|
||||||
|
|
||||||
|
@ -177,7 +177,7 @@ class Seq2SeqLMBatch(Batch):
|
|||||||
max_decoder_input_length = 0
|
max_decoder_input_length = 0
|
||||||
padding_right_offset = 0
|
padding_right_offset = 0
|
||||||
|
|
||||||
remaining_decode_tokens = 0
|
total_remaining_decode_tokens = 0
|
||||||
|
|
||||||
for i, r in enumerate(requests):
|
for i, r in enumerate(requests):
|
||||||
idx = self.requests_idx_mapping[r.id]
|
idx = self.requests_idx_mapping[r.id]
|
||||||
@ -198,18 +198,15 @@ class Seq2SeqLMBatch(Batch):
|
|||||||
max_decoder_input_length = max(
|
max_decoder_input_length = max(
|
||||||
max_decoder_input_length, request_decoder_input_length
|
max_decoder_input_length, request_decoder_input_length
|
||||||
)
|
)
|
||||||
padding_right_offset = max(
|
|
||||||
padding_right_offset,
|
|
||||||
self.stopping_criterias[idx].max_new_tokens
|
|
||||||
- self.stopping_criterias[idx].current_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
next_token_choosers.append(self.next_token_choosers[idx])
|
next_token_choosers.append(self.next_token_choosers[idx])
|
||||||
stopping_criteria = self.stopping_criterias[idx]
|
stopping_criteria = self.stopping_criterias[idx]
|
||||||
stopping_criterias.append(stopping_criteria)
|
stopping_criterias.append(stopping_criteria)
|
||||||
remaining_decode_tokens += (
|
remaining_decode_tokens = (
|
||||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||||
)
|
)
|
||||||
|
total_remaining_decode_tokens += remaining_decode_tokens
|
||||||
|
padding_right_offset = max(padding_right_offset, remaining_decode_tokens)
|
||||||
|
|
||||||
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
|
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
|
||||||
self.decoder_input_ids = self.decoder_input_ids[keep_indices]
|
self.decoder_input_ids = self.decoder_input_ids[keep_indices]
|
||||||
@ -397,7 +394,6 @@ class Seq2SeqLMBatch(Batch):
|
|||||||
[t for t in layer] for layer in batch.past_key_values
|
[t for t in layer] for layer in batch.past_key_values
|
||||||
]
|
]
|
||||||
|
|
||||||
start_index = end_index
|
|
||||||
# Add eventual padding tokens that were added while concatenating
|
# Add eventual padding tokens that were added while concatenating
|
||||||
max_tokens += batch.max_tokens + (
|
max_tokens += batch.max_tokens + (
|
||||||
max_input_length
|
max_input_length
|
||||||
@ -406,6 +402,8 @@ class Seq2SeqLMBatch(Batch):
|
|||||||
- batch.max_decoder_input_length
|
- batch.max_decoder_input_length
|
||||||
) * len(batch)
|
) * len(batch)
|
||||||
|
|
||||||
|
start_index = end_index
|
||||||
|
|
||||||
# Determine shapes for new past kv tensors
|
# Determine shapes for new past kv tensors
|
||||||
first_past_kvs = batches[0].past_key_values
|
first_past_kvs = batches[0].past_key_values
|
||||||
_, num_heads, _, head_dim = first_past_kvs[0][0].shape
|
_, num_heads, _, head_dim = first_past_kvs[0][0].shape
|
||||||
|
Loading…
Reference in New Issue
Block a user