From 5db40a5109a4abda5fb5a449df48229c5e47e936 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 8 Dec 2022 18:36:17 +0100
Subject: [PATCH] black

---
 server/tests/conftest.py                   |   4 +-
 server/tests/models/test_bloom.py          |  94 ++++++++++-----
 server/tests/models/test_causal_lm.py      | 130 ++++++++++++++-------
 server/tests/models/test_seq2seq_lm.py     | 128 ++++++++++++++------
 server/tests/test_utils.py                 |   8 +-
 server/text_generation/models/bloom.py     |  26 +++--
 server/text_generation/models/causal_lm.py |  56 ++++-----
 7 files changed, 299 insertions(+), 147 deletions(-)

diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index 0dcefae6..0640d45d 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -29,6 +29,8 @@ def gpt2_tokenizer():
 
 @pytest.fixture(scope="session")
 def mt0_small_tokenizer():
-    tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-small", padding_side="left")
+    tokenizer = AutoTokenizer.from_pretrained(
+        "bigscience/mt0-small", padding_side="left"
+    )
     tokenizer.bos_token_id = 0
     return tokenizer
diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py
index 1ebf06f4..49dabb14 100644
--- a/server/tests/models/test_bloom.py
+++ b/server/tests/models/test_bloom.py
@@ -21,16 +21,14 @@ def default_pb_request(default_pb_parameters):
 
 @pytest.fixture
 def default_pb_batch(default_pb_request):
-    return generate_pb2.Batch(
-        id=0,
-        requests=[default_pb_request],
-        size=1
-    )
+    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
 
 
 @pytest.fixture
 def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
-    return BloomCausalLMBatch.from_pb(default_pb_batch, bloom_560m_tokenizer, torch.device("cpu"))
+    return BloomCausalLMBatch.from_pb(
+        default_pb_batch, bloom_560m_tokenizer, torch.device("cpu")
+    )
 
 
 @pytest.fixture
@@ -40,12 +38,10 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer)
     req_1.id = 1
     req_1.max_new_tokens = 5
 
-    batch_pb = generate_pb2.Batch(
-        id=0,
-        requests=[req_0, req_1],
-        size=2
+    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
+    return BloomCausalLMBatch.from_pb(
+        batch_pb, bloom_560m_tokenizer, torch.device("cpu")
     )
-    return BloomCausalLMBatch.from_pb(batch_pb, bloom_560m_tokenizer, torch.device("cpu"))
 
 
 @pytest.fixture(scope="session")
@@ -126,13 +122,20 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert len(generated_texts) == 1
     assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
     assert generated_texts[0].request == default_bloom_batch.requests[0]
-    assert generated_texts[0].tokens == default_bloom_batch.stopping_criterias[0].max_new_tokens
+    assert (
+        generated_texts[0].tokens
+        == default_bloom_batch.stopping_criterias[0].max_new_tokens
+    )
 
 
-def test_causal_lm_generate_token_completion_multi(default_bloom, default_multi_requests_bloom_batch):
+def test_causal_lm_generate_token_completion_multi(
+    default_bloom, default_multi_requests_bloom_batch
+):
     next_batch = default_multi_requests_bloom_batch
 
-    for i in range(default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1):
+    for i in range(
+        default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
+    ):
         generated_texts, next_batch = default_bloom.generate_token(next_batch)
         assert generated_texts == []
@@ -142,11 +145,16 @@ def test_causal_lm_generate_token_completion_multi(default_bloom, default_multi_
     assert len(generated_texts) == 1
     assert generated_texts[0].output == "TestTestTestTestTestTest"
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
-    assert generated_texts[0].tokens == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
+    assert (
+        generated_texts[0].tokens
+        == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
+    )
 
     for _ in range(
-            default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens -
-            default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1):
+        default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
+        - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
+        - 1
+    ):
         generated_texts, next_batch = default_bloom.generate_token(next_batch)
         assert generated_texts == []
@@ -156,10 +164,15 @@ def test_causal_lm_generate_token_completion_multi(default_bloom, default_multi_
     assert len(generated_texts) == 1
     assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
-    assert generated_texts[0].tokens == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
+    assert (
+        generated_texts[0].tokens
+        == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
+    )
 
 
-def test_batch_concatenate(default_bloom, default_bloom_batch, default_multi_requests_bloom_batch):
+def test_batch_concatenate(
+    default_bloom, default_bloom_batch, default_multi_requests_bloom_batch
+):
     next_batch_0 = default_bloom_batch
     _, next_batch_0 = default_bloom.generate_token(next_batch_0)
     _, next_batch_0 = default_bloom.generate_token(next_batch_0)
@@ -198,12 +211,20 @@ def test_batch_concatenate(default_bloom, default_bloom_batch, default_multi_req
     for i, past in enumerate(next_batch.past_key_values):
         assert torch.equal(next_batch_0.past_key_values[i][0][:, :, -2:], past[0][0])
-        assert torch.equal(next_batch_1.past_key_values[i][0][:, :, -1:], past[0][1:, :, :, -1].reshape(-1, 64, 1))
+        assert torch.equal(
+            next_batch_1.past_key_values[i][0][:, :, -1:],
+            past[0][1:, :, :, -1].reshape(-1, 64, 1),
+        )
 
         assert torch.equal(next_batch_0.past_key_values[i][1][:, -2:, :], past[1][0])
-        assert torch.equal(next_batch_1.past_key_values[i][1][:, -1:, :], past[1][1:, :, -1, :].reshape(-1, 1, 64))
+        assert torch.equal(
+            next_batch_1.past_key_values[i][1][:, -1:, :],
+            past[1][1:, :, -1, :].reshape(-1, 1, 64),
+        )
 
-    for _ in range(default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2):
+    for _ in range(
+        default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
+    ):
         generated_texts, next_batch = default_bloom.generate_token(next_batch)
         assert generated_texts == []
@@ -213,11 +234,16 @@ def test_batch_concatenate(default_bloom, default_bloom_batch, default_multi_req
     assert len(generated_texts) == 1
     assert generated_texts[0].output == "TestTestTestTestTestTest"
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
-    assert generated_texts[0].tokens == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
+    assert (
+        generated_texts[0].tokens
+        == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
+    )
 
     for _ in range(
-            default_bloom_batch.stopping_criterias[0].max_new_tokens -
-            default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2):
+        default_bloom_batch.stopping_criterias[0].max_new_tokens
+        - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
+        - 2
+    ):
         generated_texts, next_batch = default_bloom.generate_token(next_batch)
         assert generated_texts == []
@@ -227,12 +253,17 @@ def test_batch_concatenate(default_bloom, default_bloom_batch, default_multi_req
     assert len(generated_texts) == 1
     assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
     assert generated_texts[0].request == default_bloom_batch.requests[0]
-    assert generated_texts[0].tokens == default_bloom_batch.stopping_criterias[0].max_new_tokens
+    assert (
+        generated_texts[0].tokens
+        == default_bloom_batch.stopping_criterias[0].max_new_tokens
+    )
 
     for _ in range(
-            default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens -
-            default_bloom_batch.stopping_criterias[0].max_new_tokens -
-            default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 4):
+        default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
+        - default_bloom_batch.stopping_criterias[0].max_new_tokens
+        - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
+        - 4
+    ):
         generated_texts, next_batch = default_bloom.generate_token(next_batch)
         assert generated_texts == []
@@ -242,4 +273,7 @@ def test_batch_concatenate(default_bloom, default_bloom_batch, default_multi_req
     assert len(generated_texts) == 1
     assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
-    assert generated_texts[0].tokens == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
+    assert (
+        generated_texts[0].tokens
+        == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
+    )
diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
index 59befd3f..1bf3e5e6 100644
--- a/server/tests/models/test_causal_lm.py
+++ b/server/tests/models/test_causal_lm.py
@@ -20,11 +20,7 @@ def default_pb_request(default_pb_parameters):
 
 @pytest.fixture
 def default_pb_batch(default_pb_request):
-    return generate_pb2.Batch(
-        id=0,
-        requests=[default_pb_request],
-        size=1
-    )
+    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
 
 
 @pytest.fixture
@@ -39,11 +35,7 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
     req_1.id = 1
     req_1.max_new_tokens = 5
 
-    batch_pb = generate_pb2.Batch(
-        id=0,
-        requests=[req_0, req_1],
-        size=2
-    )
+    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
     return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
 
 
@@ -88,7 +80,9 @@ def test_causal_lm_batch_type(default_causal_lm):
 
 
 def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
-    generated_texts, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
+    generated_texts, next_batch = default_causal_lm.generate_token(
+        default_causal_lm_batch
+    )
 
     assert generated_texts == []
     assert isinstance(next_batch, CausalLMBatch)
@@ -113,7 +107,9 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert all([p[1].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])
 
 
-def test_causal_lm_generate_token_completion(default_causal_lm, default_causal_lm_batch):
+def test_causal_lm_generate_token_completion(
+    default_causal_lm, default_causal_lm_batch
+):
     next_batch = default_causal_lm_batch
     for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
         generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
@@ -123,15 +119,25 @@ def test_causal_lm_generate_token_completion(default_causal_lm, default_causal_l
     assert next_batch is None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test Test Test Test Test Test Test Test Test Test Test"
+    assert (
+        generated_texts[0].output
+        == "Test Test Test Test Test Test Test Test Test Test Test"
+    )
     assert generated_texts[0].request == default_causal_lm_batch.requests[0]
-    assert generated_texts[0].tokens == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
+    assert (
+        generated_texts[0].tokens
+        == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
+    )
 
 
-def test_causal_lm_generate_token_completion_multi(default_causal_lm, default_multi_requests_causal_lm_batch):
+def test_causal_lm_generate_token_completion_multi(
+    default_causal_lm, default_multi_requests_causal_lm_batch
+):
     next_batch = default_multi_requests_causal_lm_batch
 
-    for i in range(default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1):
+    for i in range(
+        default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
+    ):
         generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
         assert generated_texts == []
@@ -140,12 +146,19 @@ def test_causal_lm_generate_token_completion_multi(default_causal_lm, default_mu
 
     assert len(generated_texts) == 1
     assert generated_texts[0].output == "Test Test Test Test Test Test"
-    assert generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
-    assert generated_texts[0].tokens == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
+    assert (
+        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
+    )
+    assert (
+        generated_texts[0].tokens
+        == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
+    )
 
     for _ in range(
-            default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens -
-            default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1):
+        default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
+        - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
+        - 1
+    ):
         generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
         assert generated_texts == []
@@ -153,12 +166,22 @@ def test_causal_lm_generate_token_completion_multi(default_causal_lm, default_mu
     assert next_batch is None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test Test Test Test Test Test Test Test Test Test Test"
-    assert generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
-    assert generated_texts[0].tokens == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
+    assert (
+        generated_texts[0].output
+        == "Test Test Test Test Test Test Test Test Test Test Test"
+    )
+    assert (
+        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
+    )
+    assert (
+        generated_texts[0].tokens
+        == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
+    )
 
 
-def test_batch_concatenate(default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch):
+def test_batch_concatenate(
+    default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch
+):
     next_batch_0 = default_causal_lm_batch
     _, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
     _, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
@@ -197,12 +220,18 @@ def test_batch_concatenate(default_causal_lm, default_causal_lm_batch, default_m
     for i, past in enumerate(next_batch.past_key_values):
         assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:], past[0][0])
-        assert torch.equal(next_batch_1.past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :])
+        assert torch.equal(
+            next_batch_1.past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
+        )
 
         assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:], past[1][0])
-        assert torch.equal(next_batch_1.past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :])
+        assert torch.equal(
+            next_batch_1.past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
+        )
 
-    for _ in range(default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2):
+    for _ in range(
+        default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
+    ):
         generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
         assert generated_texts == []
@@ -211,12 +240,19 @@ def test_batch_concatenate(default_causal_lm, default_causal_lm_batch, default_m
 
     assert len(generated_texts) == 1
     assert generated_texts[0].output == "Test Test Test Test Test Test"
-    assert generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
-    assert generated_texts[0].tokens == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
+    assert (
+        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
+    )
+    assert (
+        generated_texts[0].tokens
+        == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
+    )
 
     for _ in range(
-            default_causal_lm_batch.stopping_criterias[0].max_new_tokens -
-            default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2):
+        default_causal_lm_batch.stopping_criterias[0].max_new_tokens
+        - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
+        - 2
+    ):
         generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
         assert generated_texts == []
@@ -224,14 +260,22 @@ def test_batch_concatenate(default_causal_lm, default_causal_lm_batch, default_m
     assert next_batch is not None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test Test Test Test Test Test Test Test Test Test Test"
+    assert (
+        generated_texts[0].output
+        == "Test Test Test Test Test Test Test Test Test Test Test"
+    )
     assert generated_texts[0].request == default_causal_lm_batch.requests[0]
-    assert generated_texts[0].tokens == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
+    assert (
+        generated_texts[0].tokens
+        == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
+    )
 
     for _ in range(
-            default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens -
-            default_causal_lm_batch.stopping_criterias[0].max_new_tokens -
-            default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 4):
+        default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
+        - default_causal_lm_batch.stopping_criterias[0].max_new_tokens
+        - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
+        - 4
+    ):
         generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
         assert generated_texts == []
@@ -239,6 +283,14 @@ def test_batch_concatenate(default_causal_lm, default_causal_lm_batch, default_m
     assert next_batch is None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test Test Test Test Test Test Test Test Test Test Test"
-    assert generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
-    assert generated_texts[0].tokens == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
+    assert (
+        generated_texts[0].output
+        == "Test Test Test Test Test Test Test Test Test Test Test"
+    )
+    assert (
+        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
+    )
+    assert (
+        generated_texts[0].tokens
+        == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
+    )
diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py
index 257b7b52..7e4c7fdd 100644
--- a/server/tests/models/test_seq2seq_lm.py
+++ b/server/tests/models/test_seq2seq_lm.py
@@ -20,16 +20,14 @@ def default_pb_request(default_pb_parameters):
 
 @pytest.fixture
 def default_pb_batch(default_pb_request):
-    return generate_pb2.Batch(
-        id=0,
-        requests=[default_pb_request],
-        size=1
-    )
+    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
 
 
 @pytest.fixture
 def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
-    return Seq2SeqLMBatch.from_pb(default_pb_batch, mt0_small_tokenizer, torch.device("cpu"))
+    return Seq2SeqLMBatch.from_pb(
+        default_pb_batch, mt0_small_tokenizer, torch.device("cpu")
+    )
 
 
 @pytest.fixture
@@ -39,11 +37,7 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni
     req_1.id = 1
     req_1.max_new_tokens = 5
 
-    batch_pb = generate_pb2.Batch(
-        id=0,
-        requests=[req_0, req_1],
-        size=2
-    )
+    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
     return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))
 
 
@@ -92,16 +86,22 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
 
 
 def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
-    generated_texts, next_batch = default_seq2seq_lm.generate_token(default_seq2seq_lm_batch)
+    generated_texts, next_batch = default_seq2seq_lm.generate_token(
+        default_seq2seq_lm_batch
+    )
 
     assert generated_texts == []
     assert isinstance(next_batch, Seq2SeqLMBatch)
 
     assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
-    assert torch.equal(next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask)
+    assert torch.equal(
+        next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask
+    )
     assert next_batch.input_lengths == default_seq2seq_lm_batch.input_lengths
     assert next_batch.max_input_length == default_seq2seq_lm_batch.max_input_length
-    assert next_batch.next_token_choosers == default_seq2seq_lm_batch.next_token_choosers
+    assert (
+        next_batch.next_token_choosers == default_seq2seq_lm_batch.next_token_choosers
+    )
     assert next_batch.stopping_criterias == default_seq2seq_lm_batch.stopping_criterias
 
     assert next_batch.decoder_input_ids.shape == (next_batch.size, 2)
@@ -114,13 +114,23 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
     assert next_batch.max_decoder_input_length == 2
 
     assert next_batch.past_key_values is not None
-    assert all([p[0].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values])
-    assert all([p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values])
-    assert all([p[2].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values])
-    assert all([p[3].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values])
+    assert all(
+        [p[0].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
+    )
+    assert all(
+        [p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
+    )
+    assert all(
+        [p[2].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values]
+    )
+    assert all(
+        [p[3].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values]
+    )
 
 
-def test_seq2seq_lm_generate_token_completion(default_seq2seq_lm, default_seq2seq_lm_batch):
+def test_seq2seq_lm_generate_token_completion(
+    default_seq2seq_lm, default_seq2seq_lm_batch
+):
     next_batch = default_seq2seq_lm_batch
     for _ in range(6):
         generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
@@ -135,7 +145,9 @@ def test_seq2seq_lm_generate_token_completion(default_seq2seq_lm, default_seq2se
     assert generated_texts[0].tokens == 7
 
 
-def test_seq2seq_lm_generate_token_completion_multi(default_seq2seq_lm, default_multi_requests_seq2seq_lm_batch):
+def test_seq2seq_lm_generate_token_completion_multi(
+    default_seq2seq_lm, default_multi_requests_seq2seq_lm_batch
+):
     next_batch = default_multi_requests_seq2seq_lm_batch
 
     for i in range(4):
@@ -147,7 +159,10 @@ def test_seq2seq_lm_generate_token_completion_multi(default_seq2seq_lm, default_
     assert len(generated_texts) == 1
     assert generated_texts[0].output == "a few "
-    assert generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[1]
+    assert (
+        generated_texts[0].request
+        == default_multi_requests_seq2seq_lm_batch.requests[1]
+    )
     assert generated_texts[0].tokens == 5
 
     generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
@@ -158,11 +173,18 @@ def test_seq2seq_lm_generate_token_completion_multi(default_seq2seq_lm, default_
     assert len(generated_texts) == 1
     assert generated_texts[0].output == "a few weeks"
-    assert generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[0]
+    assert (
+        generated_texts[0].request
+        == default_multi_requests_seq2seq_lm_batch.requests[0]
+    )
     assert generated_texts[0].tokens == 7
 
 
-def test_batch_concatenate(default_seq2seq_lm, default_seq2seq_lm_batch, default_multi_requests_seq2seq_lm_batch):
+def test_batch_concatenate(
+    default_seq2seq_lm,
+    default_seq2seq_lm_batch,
+    default_multi_requests_seq2seq_lm_batch,
+):
     next_batch_0 = default_seq2seq_lm_batch
     _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
     _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
@@ -179,16 +201,26 @@ def test_batch_concatenate(default_seq2seq_lm, default_seq2seq_lm_batch, default
 
     assert torch.all(next_batch.attention_mask == 1)
 
-    assert torch.equal(next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0])
+    assert torch.equal(
+        next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
+    )
     assert torch.all(next_batch.decoder_input_ids[1:, 0] == 0)
-    assert torch.equal(next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids)
+    assert torch.equal(
+        next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
+    )
 
     assert torch.all(next_batch.decoder_attention_mask[0] == 1)
     assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0)
     assert torch.all(next_batch.decoder_attention_mask[1:, -2:] == 1)
 
-    assert torch.equal(next_batch.encoder_last_hidden_state[0], next_batch_0.encoder_last_hidden_state[0, -2:])
-    assert torch.equal(next_batch.encoder_last_hidden_state[1:], next_batch_1.encoder_last_hidden_state[:, -2:])
+    assert torch.equal(
+        next_batch.encoder_last_hidden_state[0],
+        next_batch_0.encoder_last_hidden_state[0, -2:],
+    )
+    assert torch.equal(
+        next_batch.encoder_last_hidden_state[1:],
+        next_batch_1.encoder_last_hidden_state[:, -2:],
+    )
 
     assert next_batch.input_lengths == [2, 2, 2]
     assert next_batch.decoder_input_lengths == [3, 2, 2]
@@ -205,23 +237,39 @@ def test_batch_concatenate(default_seq2seq_lm, default_seq2seq_lm_batch, default
     assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
 
     assert next_batch.past_key_values is not None
-    assert all([p[0].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values])
-    assert all([p[1].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values])
-    assert all([p[2].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values])
-    assert all([p[3].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values])
+    assert all(
+        [p[0].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+    )
+    assert all(
+        [p[1].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+    )
+    assert all(
+        [p[2].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+    )
+    assert all(
+        [p[3].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+    )
 
     for i, past in enumerate(next_batch.past_key_values):
         assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:, :], past[0][0])
-        assert torch.equal(next_batch_1.past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :])
+        assert torch.equal(
+            next_batch_1.past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :]
+        )
 
         assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:, :], past[1][0])
-        assert torch.equal(next_batch_1.past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :])
+        assert torch.equal(
+            next_batch_1.past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :]
+        )
 
         assert torch.equal(next_batch_0.past_key_values[i][2][0, :, -2:, :], past[2][0])
-        assert torch.equal(next_batch_1.past_key_values[i][2][:, :, -2:, :], past[2][1:])
+        assert torch.equal(
+            next_batch_1.past_key_values[i][2][:, :, -2:, :], past[2][1:]
+        )
 
         assert torch.equal(next_batch_0.past_key_values[i][3][0, :, -2:, :], past[3][0])
-        assert torch.equal(next_batch_1.past_key_values[i][3][:, :, -2:, :], past[3][1:])
+        assert torch.equal(
+            next_batch_1.past_key_values[i][3][:, :, -2:, :], past[3][1:]
+        )
 
     for _ in range(3):
         generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
@@ -232,7 +280,10 @@ def test_batch_concatenate(default_seq2seq_lm, default_seq2seq_lm_batch, default
     assert len(generated_texts) == 1
     assert generated_texts[0].output == "a few "
-    assert generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[1]
+    assert (
+        generated_texts[0].request
+        == default_multi_requests_seq2seq_lm_batch.requests[1]
+    )
     assert generated_texts[0].tokens == 5
 
     generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
@@ -248,5 +299,8 @@ def test_batch_concatenate(default_seq2seq_lm, default_seq2seq_lm_batch, default
     assert len(generated_texts) == 1
     assert generated_texts[0].output == "a few weeks"
-    assert generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[0]
+    assert (
+        generated_texts[0].request
+        == default_multi_requests_seq2seq_lm_batch.requests[0]
+    )
     assert generated_texts[0].tokens == 7
diff --git a/server/tests/test_utils.py b/server/tests/test_utils.py
index 4a0bc13b..e630ebda 100644
--- a/server/tests/test_utils.py
+++ b/server/tests/test_utils.py
@@ -1,6 +1,11 @@
 import pytest
 
-from text_generation.utils import weight_hub_files, download_weights, weight_files, LocalEntryNotFoundError
+from text_generation.utils import (
+    weight_hub_files,
+    download_weights,
+    weight_files,
+    LocalEntryNotFoundError,
+)
 
 
 def test_weight_hub_files():
@@ -27,4 +32,3 @@ def test_download_weights():
 def test_weight_files_error():
     with pytest.raises(LocalEntryNotFoundError):
         weight_files("bert-base-uncased")
-
diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py
index d34bcc09..20e26419 100644
--- a/server/text_generation/models/bloom.py
+++ b/server/text_generation/models/bloom.py
@@ -34,9 +34,11 @@ torch.manual_seed(0)
 class BloomCausalLMBatch(CausalLMBatch):
     @classmethod
     def from_pb(
-            cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
     ) -> "CausalLMBatch":
-        batch = super(BloomCausalLMBatch, cls).from_pb(pb=pb, tokenizer=tokenizer, device=device)
+        batch = super(BloomCausalLMBatch, cls).from_pb(
+            pb=pb, tokenizer=tokenizer, device=device
+        )
         batch.keys_head_dim_last = False
         return batch
@@ -105,17 +107,17 @@ class BLOOMSharded(BLOOM):
 
     @staticmethod
     def load_weights(
-            model,
-            filenames: List[str],
-            quantize: bool,
-            device: torch.device,
-            rank: int,
-            world_size: int,
+        model,
+        filenames: List[str],
+        quantize: bool,
+        device: torch.device,
+        rank: int,
+        world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                    file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if not quantize else "cpu"
             ) as f:
                 for name in f.keys():
                     full_name = f"transformer.{name}"
@@ -178,9 +180,9 @@ class BLOOMSharded(BLOOM):
                             )
 
                         if (
-                                type(module)
-                                in [TensorParallelRowLinear, TensorParallelColumnLinear]
-                                and param_name == "weight"
+                            type(module)
+                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
+                            and param_name == "weight"
                         ):
                             tensor = Int8Params(
                                 tensor.transpose(1, 0),
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 7d95bca6..eb6ce064 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -46,7 +46,7 @@ class CausalLMBatch:
 
     @classmethod
     def from_pb(
-            cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []
@@ -147,8 +147,8 @@ class CausalLMBatch:
 
             # We need to slice the attention mask to remove padding from previous steps
             attention_mask[
-                start_index:end_index, -batch.max_sequence_length:
-            ] = batch.attention_mask[:, -batch.max_sequence_length:]
+                start_index:end_index, -batch.max_sequence_length :
+            ] = batch.attention_mask[:, -batch.max_sequence_length :]
 
             for j, past in enumerate(batch.past_key_values):
                 past_keys, past_values = past
@@ -196,22 +196,22 @@ class CausalLMBatch:
                 # We slice the past keys and values to remove the padding from previous batches
                 if batch.keys_head_dim_last:
                     past_key_values[j][0][
-                            start_index:end_index,
-                            :,
-                            -(batch.max_sequence_length - 1):,
-                            :,
-                    ] = past_keys[:, :, -(batch.max_sequence_length - 1):, :]
+                        start_index:end_index,
+                        :,
+                        -(batch.max_sequence_length - 1) :,
+                        :,
+                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
                 else:
                     past_key_values[j][0][
-                            start_index:end_index,
-                            :,
-                            :,
-                            -(batch.max_sequence_length - 1):,
-                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):]
+                        start_index:end_index,
+                        :,
+                        :,
+                        -(batch.max_sequence_length - 1) :,
+                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
 
                 past_key_values[j][1][
-                    start_index:end_index, :, -(batch.max_sequence_length - 1):, :
-                ] = past_values[:, :, -(batch.max_sequence_length - 1):, :]
+                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
+                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
 
             start_index += batch.size
@@ -227,7 +227,7 @@ class CausalLMBatch:
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
             max_sequence_length=max_sequence_length,
-            keys_head_dim_last=batches[0].keys_head_dim_last
+            keys_head_dim_last=batches[0].keys_head_dim_last,
         )
@@ -250,7 +250,11 @@ class CausalLM(Model):
             device_map="auto" if torch.cuda.is_available() else None,
             load_in_8bit=quantize,
         ).eval()
-        tokenizer.pad_token_id = self.model.config.pad_token_id if self.model.config.pad_token_id is not None else self.model.config.eos_token_id
+        tokenizer.pad_token_id = (
+            self.model.config.pad_token_id
+            if self.model.config.pad_token_id is not None
+            else self.model.config.eos_token_id
+        )
 
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -262,7 +266,7 @@ class CausalLM(Model):
         return CausalLMBatch
 
     def forward(
-            self, input_ids, attention_mask, past_key_values: Optional = None
+        self, input_ids, attention_mask, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         outputs = self.model.forward(
@@ -274,7 +278,7 @@ class CausalLM(Model):
         return outputs.logits, outputs.past_key_values
 
     def generate_token(
-            self, batch: CausalLMBatch
+        self, batch: CausalLMBatch
     ) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]:
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
         context_manager = (
@@ -312,12 +316,12 @@ class CausalLM(Model):
 
         # For each member of the batch
         for i, (
-                request,
-                input_length,
-                logits,
-                next_token_chooser,
-                stopping_criteria,
-                all_tokens,
+            request,
+            input_length,
+            logits,
+            next_token_chooser,
+            stopping_criteria,
+            all_tokens,
         ) in enumerate(iterator):
             # Select next token
             next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
@@ -400,6 +404,6 @@ class CausalLM(Model):
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
             max_sequence_length=next_batch_max_sequence_length,
-            keys_head_dim_last=batch.keys_head_dim_last
+            keys_head_dim_last=batch.keys_head_dim_last,
         )
         return generated_texts, next_batch