Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 11:24:53 +00:00)

Commit 0b1d0010a4: update unit tests
Parent: 12326eff62
@@ -175,12 +175,14 @@ def test_causal_lm_generate_token_completion_multi(
         generations[1].generated_text.generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
+    # Copy stopping_criterias before filtering
+    stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()

     next_batch = next_batch.filter([next_batch.requests[0]])

     for _ in range(
-        default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
-        - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
+        stopping_criterias[0].max_new_tokens
+        - stopping_criterias[1].max_new_tokens
         - 1
     ):
         generations, next_batch = default_bloom.generate_token(next_batch)
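Why the extra copy in this hunk matters: if filtering prunes the batch's per-request bookkeeping in place (an assumption made here for illustration), a later lookup through the original fixture no longer sees the dropped request's stopping criterion, so the loop bound must come from a snapshot taken beforehand. A minimal sketch of that aliasing pitfall, using a hypothetical ToyBatch rather than the real BloomCausalLMBatch/filter implementation:

from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyBatch:
    # one max_new_tokens budget per request (toy stand-in, not the real batch class)
    stopping_criterias: List[int] = field(default_factory=lambda: [20, 10])

    def filter(self, keep: List[int]) -> "ToyBatch":
        # assume, for illustration, that filtering prunes the per-request list in place
        self.stopping_criterias = [self.stopping_criterias[i] for i in keep]
        return self


batch = ToyBatch()
snapshot = batch.stopping_criterias.copy()  # snapshot taken before filtering
batch = batch.filter([0])

assert len(batch.stopping_criterias) == 1  # the filtered batch only keeps request 0
assert len(snapshot) == 2                  # the snapshot still sees both budgets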
@@ -212,6 +214,15 @@ def test_batch_concatenate(
     next_batch_1 = default_multi_requests_bloom_batch
     _, next_batch_1 = default_bloom.generate_token(next_batch_1)

+    # Clone past_key_values before concatenating to compare after,
+    # because they are removed from the concatenated batches
+    next_batch_0_past_key_values = [
+        (k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values
+    ]
+    next_batch_1_past_key_values = [
+        (k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values
+    ]
+
     next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1])

     assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
@@ -246,15 +257,15 @@ def test_batch_concatenate(
     assert all([p[1].shape == (3, 16, 2, 64) for p in next_batch.past_key_values])

     for i, past in enumerate(next_batch.past_key_values):
-        assert torch.equal(next_batch_0.past_key_values[i][0][:, :, -2:], past[0][0])
+        assert torch.equal(next_batch_0_past_key_values[i][0][:, :, -2:], past[0][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][0][:, :, -1:],
+            next_batch_1_past_key_values[i][0][:, :, -1:],
             past[0][1:, :, :, -1].reshape(-1, 64, 1),
         )

-        assert torch.equal(next_batch_0.past_key_values[i][1][:, -2:, :], past[1][0])
+        assert torch.equal(next_batch_0_past_key_values[i][1][:, -2:, :], past[1][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][1][:, -1:, :],
+            next_batch_1_past_key_values[i][1][:, -1:, :],
             past[1][1:, :, -1, :].reshape(-1, 1, 64),
         )
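The second and third hunks apply the matching pattern to tensors: the new test comment states that past_key_values are removed from the source batches when they are concatenated, so the assertions afterwards have to compare against clones taken beforehand. Below is a minimal, self-contained sketch of that pattern with a toy concatenate helper (not the real BloomCausalLMBatch.concatenate; shapes are illustrative). The same two changes are then repeated for the generic causal LM tests in the hunks that follow.

import torch


def toy_concatenate(batches):
    # toy stand-in: merge the caches along the batch dimension and drop them
    # from the source batches, as the test comment describes
    merged = torch.cat([b["past"] for b in batches], dim=0)
    for b in batches:
        b["past"] = None
    return merged


b0 = {"past": torch.randn(1, 16, 2, 64)}
b1 = {"past": torch.randn(2, 16, 2, 64)}

b0_past = b0["past"].clone()  # clone before concatenating
b1_past = b1["past"].clone()

merged = toy_concatenate([b0, b1])

assert torch.equal(merged[:1], b0_past)  # clones remain available for comparison
assert torch.equal(merged[1:], b1_past)
assert b0["past"] is None and b1["past"] is None  # sources no longer hold theirs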
@@ -173,12 +173,14 @@ def test_causal_lm_generate_token_completion_multi(
         generations[1].generated_text.generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
+    # Copy stopping_criterias before filtering
+    stopping_criterias = default_multi_requests_causal_lm_batch.stopping_criterias.copy()

     next_batch = next_batch.filter([next_batch.requests[0]])

     for _ in range(
-        default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
-        - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
+        stopping_criterias[0].max_new_tokens
+        - stopping_criterias[1].max_new_tokens
         - 1
     ):
         generations, next_batch = default_causal_lm.generate_token(next_batch)
@@ -209,6 +211,15 @@ def test_batch_concatenate(
     next_batch_1 = default_multi_requests_causal_lm_batch
     _, next_batch_1 = default_causal_lm.generate_token(next_batch_1)

+    # Clone past_key_values before concatenating to compare after,
+    # because they are removed from the concatenated batches
+    next_batch_0_past_key_values = [
+        (k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values
+    ]
+    next_batch_1_past_key_values = [
+        (k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values
+    ]
+
     next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1])

     assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
@@ -244,14 +255,14 @@ def test_batch_concatenate(
     assert all([p[1].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])

     for i, past in enumerate(next_batch.past_key_values):
-        assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:], past[0][0])
+        assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:], past[0][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
+            next_batch_1_past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
         )

-        assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:], past[1][0])
+        assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:], past[1][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
+            next_batch_1_past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
         )

     for _ in range(