update unit tests

This commit is contained in:
Nick Hill 2023-04-24 06:52:46 +01:00
parent 12326eff62
commit 0b1d0010a4
2 changed files with 34 additions and 12 deletions

View File

@ -175,12 +175,14 @@ def test_causal_lm_generate_token_completion_multi(
generations[1].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
# Copy stopping_criterias before filtering
stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()
next_batch = next_batch.filter([next_batch.requests[0]])
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
stopping_criterias[0].max_new_tokens
- stopping_criterias[1].max_new_tokens
- 1
):
generations, next_batch = default_bloom.generate_token(next_batch)
@ -212,6 +214,15 @@ def test_batch_concatenate(
next_batch_1 = default_multi_requests_bloom_batch
_, next_batch_1 = default_bloom.generate_token(next_batch_1)
# Clone past_key_values before concatenating to compare after,
# because they are removed from the concatenated batches
next_batch_0_past_key_values = [
(k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values
]
next_batch_1_past_key_values = [
(k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values
]
next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1])
assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
@ -246,15 +257,15 @@ def test_batch_concatenate(
assert all([p[1].shape == (3, 16, 2, 64) for p in next_batch.past_key_values])
for i, past in enumerate(next_batch.past_key_values):
assert torch.equal(next_batch_0.past_key_values[i][0][:, :, -2:], past[0][0])
assert torch.equal(next_batch_0_past_key_values[i][0][:, :, -2:], past[0][0])
assert torch.equal(
next_batch_1.past_key_values[i][0][:, :, -1:],
next_batch_1_past_key_values[i][0][:, :, -1:],
past[0][1:, :, :, -1].reshape(-1, 64, 1),
)
assert torch.equal(next_batch_0.past_key_values[i][1][:, -2:, :], past[1][0])
assert torch.equal(next_batch_0_past_key_values[i][1][:, -2:, :], past[1][0])
assert torch.equal(
next_batch_1.past_key_values[i][1][:, -1:, :],
next_batch_1_past_key_values[i][1][:, -1:, :],
past[1][1:, :, -1, :].reshape(-1, 1, 64),
)

View File

@ -173,12 +173,14 @@ def test_causal_lm_generate_token_completion_multi(
generations[1].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
# Copy stopping_criterias before filtering
stopping_criterias = default_multi_requests_causal_lm_batch.stopping_criterias.copy()
next_batch = next_batch.filter([next_batch.requests[0]])
for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
stopping_criterias[0].max_new_tokens
- stopping_criterias[1].max_new_tokens
- 1
):
generations, next_batch = default_causal_lm.generate_token(next_batch)
@ -209,6 +211,15 @@ def test_batch_concatenate(
next_batch_1 = default_multi_requests_causal_lm_batch
_, next_batch_1 = default_causal_lm.generate_token(next_batch_1)
# Clone past_key_values before concatenating to compare after,
# because they are removed from the concatenated batches
next_batch_0_past_key_values = [
(k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values
]
next_batch_1_past_key_values = [
(k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values
]
next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1])
assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
@ -244,14 +255,14 @@ def test_batch_concatenate(
assert all([p[1].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])
for i, past in enumerate(next_batch.past_key_values):
assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:], past[0][0])
assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:], past[0][0])
assert torch.equal(
next_batch_1.past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
next_batch_1_past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
)
assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:], past[1][0])
assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:], past[1][0])
assert torch.equal(
next_batch_1.past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
next_batch_1_past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
)
for _ in range(