import pytest

from transformers import AutoTokenizer

from text_generation.pb import generate_pb2


@pytest.fixture
def default_pb_parameters():
    # Greedy decoding defaults: sampling disabled, with temperature, top_k,
    # and top_p at their neutral (no-op) values.
    return generate_pb2.NextTokenChooserParameters(
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        do_sample=False,
    )
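

# Minimal usage sketch (illustrative only; a test like this would normally live
# in a test module such as test_parameters.py, since pytest does not collect
# tests from conftest.py). It only checks fields the fixture above sets.
def test_default_pb_parameters_sketch(default_pb_parameters):
    assert not default_pb_parameters.do_sample
    assert default_pb_parameters.temperature == 1.0
    assert default_pb_parameters.top_k == 0
    assert default_pb_parameters.top_p == 1.0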


@pytest.fixture
def default_pb_stop_parameters():
    # Cap generation at 10 new tokens; no string stop sequences.
    return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)
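

# Usage sketch (illustrative only, same caveat as above): these criteria bound
# generation length rather than matching stop strings.
def test_default_pb_stop_parameters_sketch(default_pb_stop_parameters):
    assert default_pb_stop_parameters.max_new_tokens == 10
    assert len(default_pb_stop_parameters.stop_sequences) == 0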


@pytest.fixture(scope="session")
def bloom_560m_tokenizer():
    # Session-scoped: loading a tokenizer is slow, so one instance is shared
    # across tests. Left padding aligns batched prompts at the right edge,
    # as decoder-only generation requires.
    return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")


@pytest.fixture(scope="session")
def gpt2_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
    # GPT-2 has no pad token; reuse the end-of-text token id (50256) for padding.
    tokenizer.pad_token_id = 50256
    return tokenizer
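

# Usage sketch (illustrative only; requires downloading the gpt2 tokenizer and
# an installed torch): with padding_side="left" and a pad token configured,
# shorter prompts in a batch are padded at the start, so the final position of
# every row holds a real token, which is where decoder-only models read
# next-token logits from.
def test_gpt2_left_padding_sketch(gpt2_tokenizer):
    batch = gpt2_tokenizer(
        ["hi", "a much longer prompt"], padding=True, return_tensors="pt"
    )
    # Row 0 is shorter, so its padding (attention_mask == 0) sits on the left.
    assert batch["attention_mask"][0, 0].item() == 0
    assert batch["attention_mask"][0, -1].item() == 1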


@pytest.fixture(scope="session")
def mt0_small_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained(
        "bigscience/mt0-small", padding_side="left"
    )
    # The mt0 tokenizer does not define a BOS token by default; set it to
    # token id 0.
    tokenizer.bos_token_id = 0
    return tokenizer