Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 11:24:53 +00:00)
black

commit 889897fe69
parent 885411e747
@@ -181,9 +181,7 @@ def test_causal_lm_generate_token_completion_multi(
     next_batch = next_batch.filter([next_batch.requests[0]])
 
     for _ in range(
-        stopping_criterias[0].max_new_tokens
-        - stopping_criterias[1].max_new_tokens
-        - 1
+        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
         generations, next_batch = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)
@@ -174,14 +174,14 @@ def test_causal_lm_generate_token_completion_multi(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
     # Copy stopping_criterias before filtering
-    stopping_criterias = default_multi_requests_causal_lm_batch.stopping_criterias.copy()
+    stopping_criterias = (
+        default_multi_requests_causal_lm_batch.stopping_criterias.copy()
+    )
 
     next_batch = next_batch.filter([next_batch.requests[0]])
 
     for _ in range(
-        stopping_criterias[0].max_new_tokens
-        - stopping_criterias[1].max_new_tokens
-        - 1
+        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
         generations, next_batch = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)
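
The two test hunks above are pure formatting: black wraps or joins lines purely around its default 88-character limit, so the `.copy()` assignment gets split while the arithmetic inside `range(...)` gets joined. As a rough illustration (a sketch, not part of the commit), black can be run programmatically on the old form, assuming black is installed with default settings:

import black

# black should join the wrapped arithmetic back onto a single line, since the
# joined expression fits within the default 88-character limit; the enclosing
# `for ... range(...)` header does not fit, so it stays split.
src = (
    "for _ in range(\n"
    "    stopping_criterias[0].max_new_tokens\n"
    "    - stopping_criterias[1].max_new_tokens\n"
    "    - 1\n"
    "):\n"
    "    pass\n"
)
print(black.format_str(src, mode=black.Mode(line_length=88)))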
@@ -178,11 +178,12 @@ class CausalLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
-            remaining_decode_tokens = stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            remaining_decode_tokens = (
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            )
             total_remaining_decode_tokens += remaining_decode_tokens
             new_padding_right_offset = max(
-                new_padding_right_offset,
-                remaining_decode_tokens
+                new_padding_right_offset, remaining_decode_tokens
             )
 
         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
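
For context (not part of the diff), this bookkeeping sizes the right-hand padding of the filtered batch: each kept request still needs `max_new_tokens - current_tokens` decode slots, and the batch keeps the maximum of those. A standalone sketch with a hypothetical stand-in for the stopping-criteria object:

from dataclasses import dataclass


@dataclass
class _Stop:  # hypothetical stand-in for the real StoppingCriteria object
    max_new_tokens: int
    current_tokens: int


def right_padding_after_filter(kept: list[_Stop]) -> tuple[int, int]:
    """Return (total_remaining_decode_tokens, new_padding_right_offset)."""
    total_remaining_decode_tokens = 0
    new_padding_right_offset = 0
    for stopping_criteria in kept:
        remaining_decode_tokens = (
            stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
        )
        total_remaining_decode_tokens += remaining_decode_tokens
        new_padding_right_offset = max(new_padding_right_offset, remaining_decode_tokens)
    return total_remaining_decode_tokens, new_padding_right_offset


# Two kept requests: one needs 3 more tokens, the other 7 -> offset 7, total 10.
assert right_padding_after_filter([_Stop(5, 2), _Stop(10, 3)]) == (10, 7)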
@@ -190,8 +191,10 @@ class CausalLMBatch(Batch):
         position_ids = self.position_ids[keep_indices]
         self.attention_mask = self.attention_mask[
             keep_indices,
-            -(self.padding_right_offset + max_input_length):
-            (self.attention_mask.shape[1] - self.padding_right_offset) + new_padding_right_offset,
+            -(self.padding_right_offset + max_input_length) : (
+                self.attention_mask.shape[1] - self.padding_right_offset
+            )
+            + new_padding_right_offset,
         ]
 
         # Ensure that past_key_values tensors can be updated in-place
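
Independent of the reformat, the slice above trims the filtered attention mask to the last `max_input_length` used positions plus the new, smaller number of right-hand slots reserved for tokens still to be generated. A toy illustration (the shapes and the left-pad/right-slot layout are assumptions for the sketch, not taken from the diff):

import torch

# Layout assumed here: [left padding | input tokens | right slots for future tokens].
batch, max_input_length, padding_right_offset = 2, 5, 4
new_padding_right_offset = 2  # fewer slots needed after filtering

attention_mask = torch.zeros(batch, 3 + max_input_length + padding_right_offset)
attention_mask[:, 3 : 3 + max_input_length] = 1  # mark the real tokens

keep_indices = [0]  # keep only the first request
sliced = attention_mask[
    keep_indices,
    -(padding_right_offset + max_input_length) : (
        attention_mask.shape[1] - padding_right_offset
    )
    + new_padding_right_offset,
]
# The slice keeps the last `max_input_length` used positions plus the reduced
# number of right-hand slots.
assert sliced.shape == (1, max_input_length + new_padding_right_offset)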
@@ -329,7 +332,8 @@ class CausalLMBatch(Batch):
             # And ensure that we can update tensors in-place
             if type(batch.past_key_values[0]) == tuple:
                 batch.past_key_values = [
-                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values
+                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
+                    for layer in batch.past_key_values
                 ]
             elif batch.past_key_values[0][0].shape == 3:
                 for layer in batch.past_key_values:
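
The `view` call that black re-wraps here exposes a batch dimension on each cached key/value tensor without copying, which is what lets later code index and update the cache per request in place. A small sketch (the fused [batch * heads, seq, head_dim] layout is an assumption for illustration):

import torch

# Viewing [batch * num_heads, seq_len, head_dim] as
# [batch, num_heads, seq_len, head_dim] reuses the same storage.
batch_size, num_heads, seq_len, head_dim = 2, 4, 3, 8
t = torch.randn(batch_size * num_heads, seq_len, head_dim)
reshaped = t.view(batch_size, -1, *t.shape[-2:])
assert reshaped.shape == (batch_size, num_heads, seq_len, head_dim)
assert reshaped.data_ptr() == t.data_ptr()  # same memory, no copy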
@@ -339,7 +343,7 @@ class CausalLMBatch(Batch):
             start_index = end_index
             # Add eventual padding tokens that were added while concatenating
             max_tokens += batch.max_tokens + (
-                max_input_length - batch.max_input_length
+                max_input_length - batch.max_input_length
             ) * len(batch)
 
         first_past_kvs = batches[0].past_key_values
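
The accounting here (unchanged in substance by the reformat, which only touched whitespace) charges each concatenated batch for the padding it gains: every one of its `len(batch)` sequences grows by `max_input_length - batch.max_input_length` positions. A quick check with made-up numbers:

# Hypothetical numbers: a batch of 3 sequences, previously budgeted at 60 tokens,
# is padded from an input length of 10 up to the combined batch's 16.
batch_len, batch_max_tokens = 3, 60
max_input_length, batch_max_input_length = 16, 10
max_tokens = 0
max_tokens += batch_max_tokens + (max_input_length - batch_max_input_length) * batch_len
assert max_tokens == 60 + 6 * 3 == 78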
@@ -390,7 +394,9 @@ class CausalLMBatch(Batch):
 
                 start_index = end_index
 
-            padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape)
+            padded_past_values = first_past_kvs[j][1].new_zeros(
+                padded_past_values_shape
+            )
             start_index = 0
             for batch in batches:
                 past_values = batch.past_key_values[j][1]
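
As an aside (not from the diff), `new_zeros` is used for the padded buffer because it inherits dtype and device from the existing cache tensor. A minimal sketch with an assumed shape:

import torch

# The padded past-value buffer matches the source cache (e.g. fp16 on GPU)
# without spelling out dtype or device explicitly.
source = torch.randn(2, 3, dtype=torch.float16)
padded_past_values_shape = (4, 8, 5, 16)  # made-up shape for illustration
padded = source.new_zeros(padded_past_values_shape)
assert padded.dtype == source.dtype and padded.device == source.device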
@@ -411,7 +417,6 @@ class CausalLMBatch(Batch):
 
             past_key_values.append([padded_past_keys, padded_past_values])
 
-
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
@@ -200,7 +200,8 @@ class Seq2SeqLMBatch(Batch):
             )
             padding_right_offset = max(
                 padding_right_offset,
-                self.stopping_criterias[idx].max_new_tokens - self.stopping_criterias[idx].current_tokens
+                self.stopping_criterias[idx].max_new_tokens
+                - self.stopping_criterias[idx].current_tokens,
             )
 
             next_token_choosers.append(self.next_token_choosers[idx])
@@ -215,16 +216,22 @@ class Seq2SeqLMBatch(Batch):
         self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
         if self.decoder_attention_mask is not None:
             self.decoder_attention_mask = self.decoder_attention_mask[
-                keep_indices,
-                -(self.padding_right_offset + max_decoder_input_length):
-                (self.decoder_attention_mask.shape[1] - self.padding_right_offset) + padding_right_offset,
+                keep_indices,
+                -(self.padding_right_offset + max_decoder_input_length) : (
+                    self.decoder_attention_mask.shape[1] - self.padding_right_offset
+                )
+                + padding_right_offset,
             ]
 
-        self.encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices, -max_input_length:]
+        self.encoder_last_hidden_state = self.encoder_last_hidden_state[
+            keep_indices, -max_input_length:
+        ]
 
         # Ensure that past_key_values tensors can be updated in-place
         if type(self.past_key_values[0]) == tuple:
-            self.past_key_values = [[t for t in layer] for layer in self.past_key_values]
+            self.past_key_values = [
+                [t for t in layer] for layer in self.past_key_values
+            ]
 
         decoder_past_seq_len = max_decoder_input_length - 1
         for layer in self.past_key_values:
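
The comprehension black re-wraps here copies the tuple-of-tuples cache into nested lists, because tuples cannot be mutated when later steps assign trimmed tensors back per layer. A trivial sketch with stand-in values instead of real tensors:

# past_key_values from the model is a tuple of tuples; nested lists allow
# in-place replacement of individual entries afterwards.
past_key_values = ((1, 2), (3, 4))
past_key_values = [[t for t in layer] for layer in past_key_values]
past_key_values[0][1] = 99  # in-place update now allowed
assert past_key_values == [[1, 99], [3, 4]]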
@@ -234,8 +241,8 @@ class Seq2SeqLMBatch(Batch):
             layer[3] = layer[3][keep_indices, :, -max_input_length:]
 
         max_tokens = (
-            len(requests) * (max_input_length + max_decoder_input_length)
-            + remaining_decode_tokens
+            len(requests) * (max_input_length + max_decoder_input_length)
+            + remaining_decode_tokens
         )
 
         self.requests = requests
@@ -255,7 +262,6 @@ class Seq2SeqLMBatch(Batch):
 
         return self
 
-
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
@@ -387,15 +393,17 @@ class Seq2SeqLMBatch(Batch):
 
             # Ensure that we can update tensors in-place
             if type(batch.past_key_values[0]) == tuple:
-                batch.past_key_values = [[t for t in layer] for layer in batch.past_key_values]
+                batch.past_key_values = [
+                    [t for t in layer] for layer in batch.past_key_values
+                ]
 
             start_index = end_index
             # Add eventual padding tokens that were added while concatenating
             max_tokens += batch.max_tokens + (
-                max_input_length
-                - batch.max_input_length
-                + max_decoder_input_length
-                - batch.max_decoder_input_length
+                max_input_length
+                - batch.max_input_length
+                + max_decoder_input_length
+                - batch.max_decoder_input_length
             ) * len(batch)
 
         # Determine shapes for new past kv tensors
@@ -435,9 +443,9 @@ class Seq2SeqLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past keys and values to remove the padding from previous batches
                 past_seq_len = batch.max_decoder_input_length - 1
-                padded_past_values[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = t[:, :, -past_seq_len:, :]
+                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
+                    :, :, -past_seq_len:, :
+                ]
                 del t
 
                 start_index = end_index
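
Both spellings perform the same right-aligned copy: each batch's past values land in the last `past_seq_len` positions of the padded buffer, leaving the newly added left padding at zero. A toy illustration with made-up shapes:

import torch

# [total_batch, heads, max_seq, head_dim] buffer; one batch with a shorter
# sequence is copied into the rightmost positions along the sequence axis.
padded = torch.zeros(4, 2, 6, 8)
t = torch.ones(2, 2, 3, 8)
start_index, end_index, past_seq_len = 0, 2, 3
padded[start_index:end_index, :, -past_seq_len:, :] = t[:, :, -past_seq_len:, :]
assert padded[0, 0, -1, 0].item() == 1.0  # copied region
assert padded[0, 0, 0, 0].item() == 0.0   # left padding stays zero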
@@ -457,13 +465,12 @@ class Seq2SeqLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past keys and values to remove the padding from previous batches
                 padded_past_values[
-                    start_index:end_index, :, -batch.max_input_length:, :
-                ] = t[:, :, -batch.max_input_length:, :]
+                    start_index:end_index, :, -batch.max_input_length :, :
+                ] = t[:, :, -batch.max_input_length :, :]
                 del t
 
                 start_index = end_index
 
-
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
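
The only textual change in this last hunk is black inserting a space before the slice colon, which it does when a slice bound is a complex expression such as `-batch.max_input_length`; the indexing itself is unchanged. A quick check that both spellings are equivalent Python (the class and values below are hypothetical):

# Spaces around the slice colon are purely cosmetic; both forms index identically.
class Box:
    max_input_length = 2


b = Box()
data = list(range(6))
assert data[-b.max_input_length :] == data[-b.max_input_length:] == [4, 5]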