Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)

Commit 889897fe69 ("black")
Parent 885411e747
@@ -181,9 +181,7 @@ def test_causal_lm_generate_token_completion_multi(
     next_batch = next_batch.filter([next_batch.requests[0]])

     for _ in range(
-        stopping_criterias[0].max_new_tokens
-        - stopping_criterias[1].max_new_tokens
-        - 1
+        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
         generations, next_batch = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)
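
This commit applies the black code formatter, so hunks like the one above only change layout: the range() bound now fits within black's default 88-character line length and is joined back onto one line. As a minimal sketch of what that bound means, with made-up budgets (20 and 5 are not the values used by the real test fixtures): once the second request finishes and the batch is filtered down to the first request, that request still needs max_new_tokens[0] - max_new_tokens[1] - 1 further generate_token() calls before the call that completes it.

# Hypothetical budgets, purely illustrative; not the fixture values.
max_new_tokens = [20, 5]
remaining_steps = max_new_tokens[0] - max_new_tokens[1] - 1
assert remaining_steps == 14
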
@@ -174,14 +174,14 @@ def test_causal_lm_generate_token_completion_multi(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
     # Copy stopping_criterias before filtering
-    stopping_criterias = default_multi_requests_causal_lm_batch.stopping_criterias.copy()
+    stopping_criterias = (
+        default_multi_requests_causal_lm_batch.stopping_criterias.copy()
+    )

     next_batch = next_batch.filter([next_batch.requests[0]])

     for _ in range(
-        stopping_criterias[0].max_new_tokens
-        - stopping_criterias[1].max_new_tokens
-        - 1
+        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
         generations, next_batch = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)
@@ -178,11 +178,12 @@ class CausalLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
-            remaining_decode_tokens = stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            remaining_decode_tokens = (
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            )
             total_remaining_decode_tokens += remaining_decode_tokens
             new_padding_right_offset = max(
-                new_padding_right_offset,
-                remaining_decode_tokens
+                new_padding_right_offset, remaining_decode_tokens
             )

         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
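
For context, the rewrapped statements compute how much right padding the filtered batch still needs: the maximum number of tokens any kept request may still generate. A minimal sketch with stand-in values; the _Criteria class below is a hypothetical stand-in for the server's stopping-criteria objects, modelling only the two attributes used here.

class _Criteria:
    def __init__(self, max_new_tokens, current_tokens):
        self.max_new_tokens = max_new_tokens
        self.current_tokens = current_tokens

total_remaining_decode_tokens = 0
new_padding_right_offset = 0
for stopping_criteria in [_Criteria(20, 5), _Criteria(10, 9)]:
    remaining_decode_tokens = (
        stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
    )
    total_remaining_decode_tokens += remaining_decode_tokens
    new_padding_right_offset = max(new_padding_right_offset, remaining_decode_tokens)

assert total_remaining_decode_tokens == 16  # 15 + 1
assert new_padding_right_offset == 15
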
@@ -190,8 +191,10 @@ class CausalLMBatch(Batch):
         position_ids = self.position_ids[keep_indices]
         self.attention_mask = self.attention_mask[
             keep_indices,
-            -(self.padding_right_offset + max_input_length):
-            (self.attention_mask.shape[1] - self.padding_right_offset) + new_padding_right_offset,
+            -(self.padding_right_offset + max_input_length) : (
+                self.attention_mask.shape[1] - self.padding_right_offset
+            )
+            + new_padding_right_offset,
         ]

         # Ensure that past_key_values tensors can be updated in-place
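
The rewrapped subscript is the hardest of these hunks to read: it keeps the rows listed in keep_indices and a column window running from max_input_length positions before the old right padding up to however much right padding the filtered batch still needs. A small PyTorch sketch with invented sizes shows the same slice bounds.

import torch

# Invented sizes, purely illustrative.
attention_mask = torch.ones(4, 16)
keep_indices = [0, 2]
padding_right_offset = 3      # right padding before filtering
new_padding_right_offset = 2  # right padding still needed after filtering
max_input_length = 5          # longest remaining prompt

kept = attention_mask[
    keep_indices,
    -(padding_right_offset + max_input_length) : (
        attention_mask.shape[1] - padding_right_offset
    )
    + new_padding_right_offset,
]
# Columns 8..14 of the 16-column mask survive for the 2 kept rows.
assert kept.shape == (2, 7)
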
@@ -329,7 +332,8 @@ class CausalLMBatch(Batch):
             # And ensure that we can update tensors in-place
             if type(batch.past_key_values[0]) == tuple:
                 batch.past_key_values = [
-                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values
+                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
+                    for layer in batch.past_key_values
                 ]
             elif batch.past_key_values[0][0].shape == 3:
                 for layer in batch.past_key_values:
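
The comprehension only gained a line break; it still reshapes each cached tensor so the batch dimension comes first and individual requests can be indexed. A sketch with hypothetical shapes (the real head counts, sizes, and cache layout may differ):

import torch

batch_size, num_heads, head_dim, seq_len = 2, 4, 8, 5

# Hypothetical cache layout: keys stored as (batch * heads, head_dim, seq_len),
# values as (batch * heads, seq_len, head_dim), one (key, value) tuple per layer.
past_key_values = [
    (
        torch.zeros(batch_size * num_heads, head_dim, seq_len),
        torch.zeros(batch_size * num_heads, seq_len, head_dim),
    )
]

# Same reshape as in the hunk: split the fused batch*heads dimension so the
# tensors become batch-first views that can be updated in-place per request.
past_key_values = [
    [t.view(batch_size, -1, *t.shape[-2:]) for t in layer]
    for layer in past_key_values
]
assert past_key_values[0][0].shape == (batch_size, num_heads, head_dim, seq_len)
assert past_key_values[0][1].shape == (batch_size, num_heads, seq_len, head_dim)
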
@@ -390,7 +394,9 @@ class CausalLMBatch(Batch):

                 start_index = end_index

-            padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape)
+            padded_past_values = first_past_kvs[j][1].new_zeros(
+                padded_past_values_shape
+            )
             start_index = 0
             for batch in batches:
                 past_values = batch.past_key_values[j][1]
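
Only the call got split over three lines: new_zeros() still allocates the padded buffer with the same dtype and device as the existing past values. A minimal sketch with made-up shapes:

import torch

# Made-up shapes: pretend the existing values are fp16 and the merged batch
# needs a larger, zero-initialised buffer of the same dtype and device.
past_values = torch.zeros(2, 4, 6, 8, dtype=torch.float16)
padded_past_values_shape = (4, 4, 10, 8)

padded_past_values = past_values.new_zeros(padded_past_values_shape)
assert padded_past_values.dtype == torch.float16
assert padded_past_values.shape == padded_past_values_shape
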
@@ -411,7 +417,6 @@ class CausalLMBatch(Batch):

             past_key_values.append([padded_past_keys, padded_past_values])

-
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
@@ -200,7 +200,8 @@ class Seq2SeqLMBatch(Batch):
             )
             padding_right_offset = max(
                 padding_right_offset,
-                self.stopping_criterias[idx].max_new_tokens - self.stopping_criterias[idx].current_tokens
+                self.stopping_criterias[idx].max_new_tokens
+                - self.stopping_criterias[idx].current_tokens,
             )

             next_token_choosers.append(self.next_token_choosers[idx])
@@ -216,15 +217,21 @@ class Seq2SeqLMBatch(Batch):
         if self.decoder_attention_mask is not None:
             self.decoder_attention_mask = self.decoder_attention_mask[
                 keep_indices,
-                -(self.padding_right_offset + max_decoder_input_length):
-                (self.decoder_attention_mask.shape[1] - self.padding_right_offset) + padding_right_offset,
+                -(self.padding_right_offset + max_decoder_input_length) : (
+                    self.decoder_attention_mask.shape[1] - self.padding_right_offset
+                )
+                + padding_right_offset,
             ]

-        self.encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices, -max_input_length:]
+        self.encoder_last_hidden_state = self.encoder_last_hidden_state[
+            keep_indices, -max_input_length:
+        ]

         # Ensure that past_key_values tensors can be updated in-place
         if type(self.past_key_values[0]) == tuple:
-            self.past_key_values = [[t for t in layer] for layer in self.past_key_values]
+            self.past_key_values = [
+                [t for t in layer] for layer in self.past_key_values
+            ]

         decoder_past_seq_len = max_decoder_input_length - 1
         for layer in self.past_key_values:
@@ -255,7 +262,6 @@ class Seq2SeqLMBatch(Batch):

         return self

-
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
@@ -387,7 +393,9 @@ class Seq2SeqLMBatch(Batch):

             # Ensure that we can update tensors in-place
             if type(batch.past_key_values[0]) == tuple:
-                batch.past_key_values = [[t for t in layer] for layer in batch.past_key_values]
+                batch.past_key_values = [
+                    [t for t in layer] for layer in batch.past_key_values
+                ]

             start_index = end_index
             # Add eventual padding tokens that were added while concatenating
@@ -435,9 +443,9 @@ class Seq2SeqLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past keys and values to remove the padding from previous batches
                 past_seq_len = batch.max_decoder_input_length - 1
-                padded_past_values[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = t[:, :, -past_seq_len:, :]
+                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
+                    :, :, -past_seq_len:, :
+                ]
                 del t

                 start_index = end_index
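
Black moved the subscript onto one line, but the assignment is unchanged: each source batch's past values are copied into the right-most past_seq_len positions of its rows in the shared padded buffer. A sketch with invented shapes:

import torch

# Invented shapes: a merged buffer for 4 requests and one source batch of 2
# requests whose past sequence length is 6.
padded_past_values = torch.zeros(4, 2, 10, 8)  # (total_batch, heads, max_past, dim)
t = torch.randn(2, 2, 6, 8)                    # this batch's past values
start_index, end_index = 0, 2
past_seq_len = 6

# Right-align the source values inside the padded time dimension.
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
    :, :, -past_seq_len:, :
]
assert torch.equal(padded_past_values[start_index:end_index, :, -past_seq_len:, :], t)
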
@@ -463,7 +471,6 @@ class Seq2SeqLMBatch(Batch):

             start_index = end_index

-
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,