diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index de0ef57b..47d701eb 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -175,12 +175,14 @@ def test_causal_lm_generate_token_completion_multi( generations[1].generated_text.generated_tokens == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens ) + # Copy stopping_criterias before filtering + stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy() next_batch = next_batch.filter([next_batch.requests[0]]) for _ in range( - default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens - - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens + stopping_criterias[0].max_new_tokens + - stopping_criterias[1].max_new_tokens - 1 ): generations, next_batch = default_bloom.generate_token(next_batch) @@ -212,6 +214,15 @@ def test_batch_concatenate( next_batch_1 = default_multi_requests_bloom_batch _, next_batch_1 = default_bloom.generate_token(next_batch_1) + # Clone past_key_values before concatenating to compare after, + # because they are removed from the concatenated batches + next_batch_0_past_key_values = [ + (k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values + ] + next_batch_1_past_key_values = [ + (k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values + ] + next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1]) assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0]) @@ -246,15 +257,15 @@ def test_batch_concatenate( assert all([p[1].shape == (3, 16, 2, 64) for p in next_batch.past_key_values]) for i, past in enumerate(next_batch.past_key_values): - assert torch.equal(next_batch_0.past_key_values[i][0][:, :, -2:], past[0][0]) + assert torch.equal(next_batch_0_past_key_values[i][0][:, :, -2:], past[0][0]) assert torch.equal( - next_batch_1.past_key_values[i][0][:, :, -1:], + next_batch_1_past_key_values[i][0][:, :, -1:], past[0][1:, :, :, -1].reshape(-1, 64, 1), ) - assert torch.equal(next_batch_0.past_key_values[i][1][:, -2:, :], past[1][0]) + assert torch.equal(next_batch_0_past_key_values[i][1][:, -2:, :], past[1][0]) assert torch.equal( - next_batch_1.past_key_values[i][1][:, -1:, :], + next_batch_1_past_key_values[i][1][:, -1:, :], past[1][1:, :, -1, :].reshape(-1, 1, 64), ) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index ad79a4ca..03d3ef9b 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -173,12 +173,14 @@ def test_causal_lm_generate_token_completion_multi( generations[1].generated_text.generated_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens ) + # Copy stopping_criterias before filtering + stopping_criterias = default_multi_requests_causal_lm_batch.stopping_criterias.copy() next_batch = next_batch.filter([next_batch.requests[0]]) for _ in range( - default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens - - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens + stopping_criterias[0].max_new_tokens + - stopping_criterias[1].max_new_tokens - 1 ): generations, next_batch = default_causal_lm.generate_token(next_batch) @@ -209,6 +211,15 @@ def test_batch_concatenate( next_batch_1 = default_multi_requests_causal_lm_batch _, next_batch_1 = default_causal_lm.generate_token(next_batch_1) + # Clone past_key_values before concatenating to compare after, + # because they are removed from the concatenated batches + next_batch_0_past_key_values = [ + (k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values + ] + next_batch_1_past_key_values = [ + (k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values + ] + next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1]) assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0]) @@ -244,14 +255,14 @@ def test_batch_concatenate( assert all([p[1].shape == (3, 12, 2, 64) for p in next_batch.past_key_values]) for i, past in enumerate(next_batch.past_key_values): - assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:], past[0][0]) + assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:], past[0][0]) assert torch.equal( - next_batch_1.past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :] + next_batch_1_past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :] ) - assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:], past[1][0]) + assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:], past[1][0]) assert torch.equal( - next_batch_1.past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :] + next_batch_1_past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :] ) for _ in range(