From 298500a08eedaef076fa410a2cee49db079c563d Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 2 Jul 2024 15:13:24 +0000
Subject: [PATCH] Fixing the simple tests.

---
 server/tests/models/test_bloom.py                  | 8 +++++++-
 server/tests/models/test_causal_lm.py              | 2 +-
 server/tests/models/test_seq2seq_lm.py             | 2 +-
 server/text_generation_server/models/causal_lm.py  | 7 ++++++-
 server/text_generation_server/models/seq2seq_lm.py | 7 ++++++-
 5 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py
index 32ee6686..08292920 100644
--- a/server/tests/models/test_bloom.py
+++ b/server/tests/models/test_bloom.py
@@ -8,6 +8,9 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.causal_lm import CausalLMBatch
 from text_generation_server.utils import weight_hub_files, download_weights
 from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
+from text_generation_server.models.custom_modeling.bloom_modeling import (
+    BloomForCausalLM,
+)
 
 
 @pytest.fixture(scope="session")
@@ -16,7 +19,10 @@ def default_bloom():
     revision = "main"
     filenames = weight_hub_files(model_id, revision, ".safetensors")
     download_weights(filenames, model_id, revision)
-    return BLOOMSharded(model_id)
+    return BLOOMSharded(
+        model_id,
+        model_class=BloomForCausalLM,
+    )
 
 
 @pytest.fixture(scope="session")
diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
index 6e6463bc..c000ef26 100644
--- a/server/tests/models/test_causal_lm.py
+++ b/server/tests/models/test_causal_lm.py
@@ -10,7 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
 
 @pytest.fixture(scope="session")
 def default_causal_lm():
-    return CausalLM("gpt2")
+    return CausalLM.fallback("gpt2")
 
 
 @pytest.fixture(scope="session")
diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py
index 943c3b08..02666042 100644
--- a/server/tests/models/test_seq2seq_lm.py
+++ b/server/tests/models/test_seq2seq_lm.py
@@ -20,7 +20,7 @@ def mt0_small_tokenizer():
 
 @pytest.fixture(scope="session")
 def default_seq2seq_lm():
-    return Seq2SeqLM("bigscience/mt0-small")
+    return Seq2SeqLM.fallback("bigscience/mt0-small")
 
 
 @pytest.fixture
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 71a59fee..685177c7 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -609,7 +609,11 @@ class CausalLM(Model):
             else:
                 tokenizer.add_special_tokens({"pad_token": "[PAD]"})
 
-        super(CausalLM, cls).__init__(
+        self = cls.__new__(
+            cls,
+        )
+        super().__init__(
+            self,
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
@@ -617,6 +621,7 @@ class CausalLM(Model):
             dtype=dtype,
             device=device,
         )
+        return self
 
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index e3684071..38695b19 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -650,7 +650,11 @@ class Seq2SeqLM(Model):
         )
         tokenizer.bos_token_id = model.config.decoder_start_token_id
 
-        super(Seq2SeqLM, cls).__init__(
+        self = cls.__new__(
+            cls,
+        )
+        super().__init__(
+            self,
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
@@ -658,6 +662,7 @@ class Seq2SeqLM(Model):
             dtype=dtype,
             device=device,
         )
+        return self
 
     @property
     def batch_type(self) -> Type[Seq2SeqLMBatch]:
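
Note on the constructor change in the last two files: inside a classmethod,
the old super(CausalLM, cls).__init__(...) call had no instance to
initialize, so the patch allocates one explicitly with cls.__new__ and then
runs the parent initializer on it, returning the finished instance. The
following is a minimal, self-contained sketch of that pattern, not the real
TGI code; the simplified Model and CausalLM classes and the model_id field
are stand-ins for illustration only.

    # Sketch of the classmethod-constructor pattern introduced above,
    # with simplified stand-in classes.
    class Model:
        def __init__(self, model_id: str):
            self.model_id = model_id

    class CausalLM(Model):
        @classmethod
        def fallback(cls, model_id: str) -> "CausalLM":
            # Allocate the instance without running cls.__init__.
            self = cls.__new__(cls)
            # In a classmethod, zero-argument super() resolves to
            # super(CausalLM, cls), and the attribute lookup yields the
            # plain parent function, so `self` must be passed explicitly
            # as the first argument.
            super().__init__(self, model_id=model_id)
            return self

    m = CausalLM.fallback("gpt2")
    assert isinstance(m, CausalLM) and m.model_id == "gpt2"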