Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
Fixing the simple tests.
This commit is contained in: parent db9acc4418, commit 298500a08e
server/tests/models/test_bloom.py:

@@ -8,6 +8,9 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.causal_lm import CausalLMBatch
 from text_generation_server.utils import weight_hub_files, download_weights
 from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
+from text_generation_server.models.custom_modeling.bloom_modeling import (
+    BloomForCausalLM,
+)
 
 
 @pytest.fixture(scope="session")
@@ -16,7 +19,10 @@ def default_bloom():
     revision = "main"
     filenames = weight_hub_files(model_id, revision, ".safetensors")
     download_weights(filenames, model_id, revision)
-    return BLOOMSharded(model_id)
+    return BLOOMSharded(
+        model_id,
+        model_class=BloomForCausalLM,
+    )
 
 
 @pytest.fixture(scope="session")
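For readability, here is the default_bloom fixture as it reads after the change, assembled from the context lines of the two hunks above. The model_id assignment sits just outside the hunk, so its value below is assumed:

@pytest.fixture(scope="session")
def default_bloom():
    model_id = "bigscience/bloom-560m"  # assumed; this line is outside the hunk
    revision = "main"
    filenames = weight_hub_files(model_id, revision, ".safetensors")
    download_weights(filenames, model_id, revision)
    # The fixture now pins the concrete transformers implementation instead
    # of letting BLOOMSharded infer the architecture from the model_id.
    return BLOOMSharded(
        model_id,
        model_class=BloomForCausalLM,
    )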
server/tests/models/test_causal_lm.py:

@@ -10,7 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
 
 @pytest.fixture(scope="session")
 def default_causal_lm():
-    return CausalLM("gpt2")
+    return CausalLM.fallback("gpt2")
 
 
 @pytest.fixture(scope="session")
server/tests/models/test_seq2seq_lm.py:

@@ -20,7 +20,7 @@ def mt0_small_tokenizer():
 
 @pytest.fixture(scope="session")
 def default_seq2seq_lm():
-    return Seq2SeqLM("bigscience/mt0-small")
+    return Seq2SeqLM.fallback("bigscience/mt0-small")
 
 
 @pytest.fixture
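Both fixtures now build their models through the new fallback classmethod rather than the bare constructor (the model hunks below show what fallback does). A hedged sketch of a test consuming the session-scoped fixture; the import line matches the hunk header above, while the assertion is hypothetical:

import pytest
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch

@pytest.fixture(scope="session")
def default_causal_lm():
    # Built once per test session; fallback() is the plain transformers-backed path.
    return CausalLM.fallback("gpt2")

def test_batch_type(default_causal_lm):
    # Hypothetical check: a fallback-built model still reports its batch class.
    assert default_causal_lm.batch_type is CausalLMBatch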
server/text_generation_server/models/causal_lm.py:

@@ -609,7 +609,11 @@ class CausalLM(Model):
         else:
             tokenizer.add_special_tokens({"pad_token": "[PAD]"})
 
-        super(CausalLM, cls).__init__(
+        self = cls.__new__(
+            cls,
+        )
+        super().__init__(
+            self,
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
@@ -617,6 +621,7 @@ class CausalLM(Model):
             dtype=dtype,
             device=device,
         )
+        return self
 
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
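The removed call is the actual bug: inside a classmethod, super(CausalLM, cls).__init__ resolves to the unbound Model.__init__, and invoking it with keyword arguments alone raises TypeError because there is no instance to bind it to. The fix allocates the instance explicitly with cls.__new__ and hands it to the parent initializer, deliberately skipping CausalLM.__init__. A self-contained sketch of the pattern with toy classes (illustrative names and signatures, not the real TGI ones):

class Model:
    def __init__(self, model_id: str, device: str = "cpu"):
        self.model_id = model_id
        self.device = device


class CausalLM(Model):
    def __init__(self, model_id: str, model_class: type, device: str = "cpu"):
        # The primary constructor needs a concrete model_class that the
        # fallback path does not have.
        self.model_class = model_class
        super().__init__(model_id, device)

    @classmethod
    def fallback(cls, model_id: str) -> "CausalLM":
        # Allocate an instance without running CausalLM.__init__ ...
        self = cls.__new__(cls)
        # ... then initialize only the Model part. In a classmethod,
        # super() means super(CausalLM, cls), so __init__ is unbound and
        # the instance must be passed explicitly.
        super().__init__(self, model_id=model_id)
        return self


lm = CausalLM.fallback("gpt2")
print(lm.model_id, lm.device)  # -> gpt2 cpu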
server/text_generation_server/models/seq2seq_lm.py:

@@ -650,7 +650,11 @@ class Seq2SeqLM(Model):
         )
         tokenizer.bos_token_id = model.config.decoder_start_token_id
 
-        super(Seq2SeqLM, cls).__init__(
+        self = cls.__new__(
+            cls,
+        )
+        super().__init__(
+            self,
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
@@ -658,6 +662,7 @@ class Seq2SeqLM(Model):
             dtype=dtype,
             device=device,
         )
+        return self
 
     @property
     def batch_type(self) -> Type[Seq2SeqLMBatch]:
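Because cls.__new__(cls) still allocates a genuine subclass instance, properties such as batch_type in the context lines above resolve exactly as they would on a normally constructed model. A quick toy check of that behavior:

class Base:
    def __init__(self, name: str):
        self.name = name


class Child(Base):
    @classmethod
    def fallback(cls, name: str) -> "Child":
        self = cls.__new__(cls)
        super().__init__(self, name)
        return self

    @property
    def batch_type(self) -> str:
        return "ChildBatch"


c = Child.fallback("x")
print(type(c).__name__, c.batch_type)  # -> Child ChildBatch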