use flan-t5 for tests

OlivierDehaene 2023-03-09 14:39:00 +01:00
parent 5e1473f0f8
commit c5a0b65c47
5 changed files with 60 additions and 29 deletions

View File

@@ -133,6 +133,22 @@ class FinishReason(Enum):
     StopSequence = "stop_sequence"
+# Additional sequences when using the `best_of` parameter
+class BestOfSequence:
+    # Generated text
+    generated_text: str
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
 # `generate` details
 class Details:
     # Generation finish reason
@@ -145,6 +161,8 @@ class Details:
     prefill: List[PrefillToken]
     # Generated tokens
     tokens: List[Token]
+    # Additional sequences when using the `best_of` parameter
+    best_of_sequences: Optional[List[BestOfSequence]]
 # `generate` return value
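For context, a minimal sketch of how the new `best_of_sequences` field could be read from a `generate` call. This is not part of the diff: the endpoint URL is a placeholder, and it assumes the client exposes a `best_of` parameter (which requires sampling), as the new types suggest.

# Sketch only: reading `best_of_sequences` from the `generate` details.
# Assumes a text-generation-inference server at the placeholder URL below
# and that the client accepts a `best_of` parameter.
from text_generation import Client

client = Client("http://127.0.0.1:8080")
response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True)

# The best-scoring candidate is returned as the main result ...
print(response.generated_text)
# ... and the remaining candidates, if any, are listed on the details.
for sequence in response.details.best_of_sequences or []:
    print(sequence.generated_text, sequence.finish_reason, sequence.seed)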

View File

@@ -4,11 +4,6 @@ from text_generation import __version__
 from huggingface_hub.utils import build_hf_headers
-@pytest.fixture
-def bloom_model():
-    return "bigscience/bloom"
 @pytest.fixture
 def flan_t5_xxl():
     return "google/flan-t5-xxl"

View File

@@ -5,21 +5,21 @@ from text_generation.errors import NotFoundError, ValidationError
 from text_generation.types import FinishReason, PrefillToken, Token
-def test_generate(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     response = client.generate("test", max_new_tokens=1)
-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
     assert response.details.prefill[0] == PrefillToken(
-        id=9234, text="test", logprob=None
+        id=0, text="<pad>", logprob=None
     )
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0] == Token(
-        id=17, text=".", logprob=-1.75, special=False
+        id=3, text=" ", logprob=-1.984375, special=False
     )
@@ -35,8 +35,8 @@ def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
         client.generate("test", max_new_tokens=10_000)
-def test_generate_stream(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate_stream(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     responses = [
         response for response in client.generate_stream("test", max_new_tokens=1)
     ]
@@ -44,7 +44,7 @@ def test_generate_stream(bloom_url, hf_headers):
     assert len(responses) == 1
     response = responses[0]
-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
@@ -63,21 +63,21 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
 @pytest.mark.asyncio
-async def test_generate_async(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_async(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     response = await client.generate("test", max_new_tokens=1)
-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
     assert response.details.prefill[0] == PrefillToken(
-        id=9234, text="test", logprob=None
+        id=0, text="<pad>", logprob=None
     )
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0] == Token(
-        id=17, text=".", logprob=-1.75, special=False
+        id=3, text=" ", logprob=-1.984375, special=False
     )
@@ -96,8 +96,8 @@ async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
 @pytest.mark.asyncio
-async def test_generate_stream_async(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     responses = [
         response async for response in client.generate_stream("test", max_new_tokens=1)
     ]
@@ -105,7 +105,7 @@ async def test_generate_stream_async(bloom_url, hf_headers):
     assert len(responses) == 1
     response = responses[0]
-    assert response.generated_text == "."
+    assert response.generated_text == ""
    assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
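Outside of pytest, the async streaming pattern exercised by the tests above looks roughly like this. A sketch only, assuming a reachable endpoint at the placeholder URL; it is not part of the diff.

# Sketch: consuming generate_stream with the AsyncClient, mirroring the tests above.
import asyncio
from text_generation import AsyncClient

async def main():
    client = AsyncClient("http://127.0.0.1:8080")  # placeholder endpoint
    async for response in client.generate_stream("test", max_new_tokens=1):
        # Each streamed response carries the newly generated token; the final
        # one also carries the full generated_text and details.
        print(response.token.text, response.generated_text)

asyncio.run(main())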

View File

@@ -14,8 +14,8 @@ def test_get_supported_models():
     assert isinstance(get_supported_models(), list)
-def test_client(bloom_model):
-    client = InferenceAPIClient(bloom_model)
+def test_client(flan_t5_xxl):
+    client = InferenceAPIClient(flan_t5_xxl)
     assert isinstance(client, Client)
@@ -24,8 +24,8 @@ def test_client_unsupported_model(unsupported_model):
         InferenceAPIClient(unsupported_model)
-def test_async_client(bloom_model):
-    client = InferenceAPIAsyncClient(bloom_model)
+def test_async_client(flan_t5_xxl):
+    client = InferenceAPIAsyncClient(flan_t5_xxl)
     assert isinstance(client, AsyncClient)
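The same model switch carries over to direct use of the Inference API helpers; a minimal sketch (the prompt and token count are illustrative, not part of the diff):

# Sketch: pointing the Inference API client at flan-t5-xxl, as the updated tests do.
from text_generation import InferenceAPIClient

client = InferenceAPIClient("google/flan-t5-xxl")
response = client.generate("Why is the sky blue?", max_new_tokens=20)
print(response.generated_text)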

View File

@@ -150,6 +150,22 @@ class FinishReason(Enum):
     StopSequence = "stop_sequence"
+# Additional sequences when using the `best_of` parameter
+class BestOfSequence(BaseModel):
+    # Generated text
+    generated_text: str
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
 # `generate` details
 class Details(BaseModel):
     # Generation finish reason
@@ -162,6 +178,8 @@ class Details(BaseModel):
     prefill: List[PrefillToken]
     # Generated tokens
     tokens: List[Token]
+    # Additional sequences when using the `best_of` parameter
+    best_of_sequences: Optional[List[BestOfSequence]]
 # `generate` return value
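Because these are pydantic models, a raw response payload can be validated directly into the types above. A small sketch with made-up values, assuming the fields shown here are the complete set:

# Sketch with illustrative values: best_of_sequences stays None unless best_of > 1.
from text_generation.types import Details

details = Details.parse_obj(
    {
        "finish_reason": "length",
        "generated_tokens": 1,
        "seed": None,
        "prefill": [{"id": 0, "text": "<pad>", "logprob": None}],
        "tokens": [{"id": 3, "text": " ", "logprob": -1.98, "special": False}],
        "best_of_sequences": None,
    }
)
assert details.best_of_sequences is None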