From caa960834727283e72f519a75a9355e37ed84f98 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 24 May 2023 17:02:20 +0200
Subject: [PATCH] fix tests

---
 server/tests/models/test_bloom.py                     |  4 ++--
 server/tests/models/test_causal_lm.py                 |  8 ++++++--
 server/tests/models/test_santacoder.py                | 10 ++++++++--
 server/tests/models/test_seq2seq_lm.py                |  6 ++++--
 .../text_generation_server/models/flash_causal_lm.py  |  2 +-
 5 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py
index 105b3573..65f9b4dd 100644
--- a/server/tests/models/test_bloom.py
+++ b/server/tests/models/test_bloom.py
@@ -38,7 +38,7 @@ def default_pb_batch(default_pb_request):
 @pytest.fixture
 def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
     return BloomCausalLMBatch.from_pb(
-        default_pb_batch, bloom_560m_tokenizer, torch.device("cpu")
+        default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
     )
 
 
@@ -52,7 +52,7 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer)
     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
 
     return BloomCausalLMBatch.from_pb(
-        batch_pb, bloom_560m_tokenizer, torch.device("cpu")
+        batch_pb, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
     )
 
 
diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
index d8d1bd16..43676ea2 100644
--- a/server/tests/models/test_causal_lm.py
+++ b/server/tests/models/test_causal_lm.py
@@ -38,7 +38,9 @@ def default_pb_batch(default_pb_request):
 
 @pytest.fixture
 def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
-    return CausalLMBatch.from_pb(default_pb_batch, gpt2_tokenizer, torch.device("cpu"))
+    return CausalLMBatch.from_pb(
+        default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu")
+    )
 
 
 @pytest.fixture
@@ -50,7 +52,9 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
     req_1.stopping_parameters.max_new_tokens = 5
     batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
 
-    return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
+    return CausalLMBatch.from_pb(
+        batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu")
+    )
 
 
 def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py
index 8cf66d47..bef8db38 100644
--- a/server/tests/models/test_santacoder.py
+++ b/server/tests/models/test_santacoder.py
@@ -45,7 +45,10 @@ def default_fim_pb_batch(default_fim_pb_request):
 @pytest.mark.skip
 def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
     batch = CausalLMBatch.from_pb(
-        default_pb_batch, default_santacoder.tokenizer, default_santacoder.device
+        default_pb_batch,
+        default_santacoder.tokenizer,
+        default_santacoder.dtype,
+        default_santacoder.device,
     )
     next_batch = batch
 
@@ -70,7 +73,10 @@ def test_fim_santacoder_generate_token_completion(
     default_santacoder, default_fim_pb_batch
 ):
     batch = CausalLMBatch.from_pb(
-        default_fim_pb_batch, default_santacoder.tokenizer, default_santacoder.device
+        default_fim_pb_batch,
+        default_santacoder.tokenizer,
+        default_santacoder.dtype,
+        default_santacoder.device,
     )
     next_batch = batch
 
diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py
index 8fdeee60..e043a5e4 100644
--- a/server/tests/models/test_seq2seq_lm.py
+++ b/server/tests/models/test_seq2seq_lm.py
@@ -42,7 +42,7 @@ def default_pb_batch(default_pb_request):
 @pytest.fixture
 def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
     return Seq2SeqLMBatch.from_pb(
-        default_pb_batch, mt0_small_tokenizer, torch.device("cpu")
+        default_pb_batch, mt0_small_tokenizer, torch.float32, torch.device("cpu")
     )
 
 
@@ -55,7 +55,9 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni
     req_1.stopping_parameters.max_new_tokens = 5
     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
 
-    return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))
+    return Seq2SeqLMBatch.from_pb(
+        batch_pb, mt0_small_tokenizer, torch.float32, torch.device("cpu")
+    )
 
 
 def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index f1535a77..35cbe174 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -393,7 +393,7 @@ class FlashCausalLMBatch(Batch):
 
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
-            ] = batch.all_input_ids_tensor
+            ] = batch.all_input_ids_tensor[:, :max_length]
 
             cumulative_batch_size += len(batch)
 
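
The test hunks above only thread the new dtype argument (torch.float32 on CPU, or the model's dtype for SantaCoder) through the from_pb calls. The flash_causal_lm.py hunk addresses a shape mismatch when concatenating batches: the merged all_input_ids_tensor is allocated with max_length columns, an individual batch's all_input_ids_tensor can be wider, and the destination slice is clamped to max_length columns, so the source has to be truncated to the same width. Below is a minimal, self-contained sketch of that behavior; the tensor sizes and the batch_all_input_ids name are illustrative stand-ins, not taken from the repository.

    import torch

    # Combined buffer, sized to the longest sequence the merged batch can reach.
    max_length = 6
    all_input_ids_tensor = torch.zeros((4, max_length), dtype=torch.int64)

    # One source batch whose buffer was allocated wider than max_length.
    batch_all_input_ids = torch.ones((2, 8), dtype=torch.int64)
    start_index, end_index = 0, 2

    # The destination slice is clamped to max_length columns, so the source is
    # truncated to match; without `[:, :max_length]` this assignment raises a
    # size-mismatch RuntimeError ((2, 6) destination vs (2, 8) source).
    all_input_ids_tensor[
        start_index:end_index, : batch_all_input_ids.shape[1]
    ] = batch_all_input_ids[:, :max_length]

Truncating the source rather than widening the destination keeps the merged tensor at the size concatenate chose, presumably because columns past max_length can never be filled for any request in the merged batch.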