Fixing the simple tests.

Nicolas Patry 2024-07-02 15:13:24 +00:00
parent db9acc4418
commit 298500a08e
5 changed files with 21 additions and 5 deletions

View File

@@ -8,6 +8,9 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.causal_lm import CausalLMBatch
 from text_generation_server.utils import weight_hub_files, download_weights
 from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
+from text_generation_server.models.custom_modeling.bloom_modeling import (
+    BloomForCausalLM,
+)


 @pytest.fixture(scope="session")
@@ -16,7 +19,10 @@ def default_bloom():
     revision = "main"
     filenames = weight_hub_files(model_id, revision, ".safetensors")
     download_weights(filenames, model_id, revision)
-    return BLOOMSharded(model_id)
+    return BLOOMSharded(
+        model_id,
+        model_class=BloomForCausalLM,
+    )


 @pytest.fixture(scope="session")
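The fixture now names the architecture explicitly instead of letting BLOOMSharded pick it. A minimal sketch of the idea behind such a model_class parameter, using a hypothetical ShardedModel stand-in rather than the real BLOOMSharded signature:

from typing import Type


class ShardedModel:
    # Hypothetical loader: the concrete architecture is injected by the
    # caller instead of being hard-coded inside the loader.
    def __init__(self, model_id: str, model_class: Type):
        self.model_id = model_id
        self.model_class = model_class

    def load(self, config) -> object:
        # Instantiate whichever class was injected, e.g.
        # model_class=BloomForCausalLM in the fixture above.
        return self.model_class(config)

Injecting the class this way keeps a single sharded loader reusable across architectures.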

View File

@@ -10,7 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
 @pytest.fixture(scope="session")
 def default_causal_lm():
-    return CausalLM("gpt2")
+    return CausalLM.fallback("gpt2")


 @pytest.fixture(scope="session")

View File

@@ -20,7 +20,7 @@ def mt0_small_tokenizer():
 @pytest.fixture(scope="session")
 def default_seq2seq_lm():
-    return Seq2SeqLM("bigscience/mt0-small")
+    return Seq2SeqLM.fallback("bigscience/mt0-small")


 @pytest.fixture
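As in the causal LM fixture, construction now goes through the fallback classmethod rather than the constructor. An illustrative consumer of this session-scoped fixture (the test body is an assumption, not part of this commit):

from text_generation_server.models.seq2seq_lm import Seq2SeqLMBatch


def test_default_seq2seq_lm_batch_type(default_seq2seq_lm):
    # pytest builds the model once per session via
    # Seq2SeqLM.fallback("bigscience/mt0-small") and reuses it here.
    assert default_seq2seq_lm.batch_type is Seq2SeqLMBatch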

View File

@@ -609,7 +609,11 @@ class CausalLM(Model):
             else:
                 tokenizer.add_special_tokens({"pad_token": "[PAD]"})
 
-        super(CausalLM, cls).__init__(
+        self = cls.__new__(
+            cls,
+        )
+        super().__init__(
+            self,
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
@@ -617,6 +621,7 @@ class CausalLM(Model):
             dtype=dtype,
             device=device,
         )
+        return self
 
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
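The noteworthy part of this change is the construction pattern: fallback is a classmethod, so no instance exists when it runs. It allocates one with cls.__new__ and then calls the parent initializer on it by hand, which is why the second hunk adds return self. A minimal, self-contained sketch of the same pattern, assuming simplified constructor arguments (class names reused for illustration only):

class Model:
    def __init__(self, model_id: str, device: str):
        self.model_id = model_id
        self.device = device


class CausalLM(Model):
    @classmethod
    def fallback(cls, model_id: str, device: str = "cpu") -> "CausalLM":
        # Allocate an instance without running cls.__init__.
        self = cls.__new__(cls)
        # In a classmethod, zero-argument super() binds to the class,
        # so the parent __init__ must receive the instance explicitly.
        super().__init__(self, model_id=model_id, device=device)
        # Unlike __init__, a classmethod constructor has to return the
        # new instance itself, hence the added `return self`.
        return self


lm = CausalLM.fallback("gpt2")
assert lm.model_id == "gpt2" and lm.device == "cpu"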

View File

@@ -650,7 +650,11 @@ class Seq2SeqLM(Model):
         )
         tokenizer.bos_token_id = model.config.decoder_start_token_id
 
-        super(Seq2SeqLM, cls).__init__(
+        self = cls.__new__(
+            cls,
+        )
+        super().__init__(
+            self,
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
@@ -658,6 +662,7 @@ class Seq2SeqLM(Model):
             dtype=dtype,
             device=device,
         )
+        return self
 
     @property
     def batch_type(self) -> Type[Seq2SeqLMBatch]: