use flan-t5 for tests

OlivierDehaene 2023-03-09 14:39:00 +01:00
parent 5e1473f0f8
commit c5a0b65c47
5 changed files with 60 additions and 29 deletions

View File

@@ -133,6 +133,22 @@ class FinishReason(Enum):
     StopSequence = "stop_sequence"

+# Additional sequences when using the `best_of` parameter
+class BestOfSequence:
+    # Generated text
+    generated_text: str
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+
 # `generate` details
 class Details:
     # Generation finish reason
@@ -145,6 +161,8 @@ class Details:
     prefill: List[PrefillToken]
     # Generated tokens
     tokens: List[Token]
+    # Additional sequences when using the `best_of` parameter
+    best_of_sequences: Optional[List[BestOfSequence]]

 # `generate` return value
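A minimal usage sketch for the new field, assuming a text-generation-inference endpoint at a placeholder URL and that `Client.generate` exposes a `best_of` keyword:

from text_generation import Client

client = Client("http://127.0.0.1:8080")  # placeholder endpoint
response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True)

print(response.generated_text)  # best sequence, as before
# Runner-up candidates, only populated when `best_of` > 1 was requested:
for sequence in response.details.best_of_sequences or []:
    print(sequence.finish_reason, sequence.generated_text)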

View File

@@ -4,11 +4,6 @@ from text_generation import __version__
 from huggingface_hub.utils import build_hf_headers

-@pytest.fixture
-def bloom_model():
-    return "bigscience/bloom"
-
 @pytest.fixture
 def flan_t5_xxl():
     return "google/flan-t5-xxl"

View File

@@ -5,21 +5,21 @@ from text_generation.errors import NotFoundError, ValidationError
 from text_generation.types import FinishReason, PrefillToken, Token

-def test_generate(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     response = client.generate("test", max_new_tokens=1)

-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
     assert response.details.prefill[0] == PrefillToken(
-        id=9234, text="test", logprob=None
+        id=0, text="<pad>", logprob=None
     )
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0] == Token(
-        id=17, text=".", logprob=-1.75, special=False
+        id=3, text=" ", logprob=-1.984375, special=False
     )
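The new expected values reflect flan-t5 being an encoder-decoder model: decoding starts from the model's decoder start token, which for T5-family models is the pad token, so the single prefill token becomes `<pad>` with id 0. A quick, hedged way to confirm this (requires transformers and Hub access; not part of the test suite):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("google/flan-t5-xxl")
print(config.decoder_start_token_id, config.pad_token_id)  # expected: 0 0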
@ -35,8 +35,8 @@ def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
client.generate("test", max_new_tokens=10_000) client.generate("test", max_new_tokens=10_000)
def test_generate_stream(bloom_url, hf_headers): def test_generate_stream(flan_t5_xxl_url, hf_headers):
client = Client(bloom_url, hf_headers) client = Client(flan_t5_xxl_url, hf_headers)
responses = [ responses = [
response for response in client.generate_stream("test", max_new_tokens=1) response for response in client.generate_stream("test", max_new_tokens=1)
] ]
@@ -44,7 +44,7 @@ def test_generate_stream(bloom_url, hf_headers):
     assert len(responses) == 1
     response = responses[0]

-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
@@ -63,21 +63,21 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
 @pytest.mark.asyncio
-async def test_generate_async(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_async(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     response = await client.generate("test", max_new_tokens=1)

-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
     assert response.details.prefill[0] == PrefillToken(
-        id=9234, text="test", logprob=None
+        id=0, text="<pad>", logprob=None
     )
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0] == Token(
-        id=17, text=".", logprob=-1.75, special=False
+        id=3, text=" ", logprob=-1.984375, special=False
     )
@@ -96,8 +96,8 @@ async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
 @pytest.mark.asyncio
-async def test_generate_stream_async(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     responses = [
         response async for response in client.generate_stream("test", max_new_tokens=1)
     ]
@@ -105,7 +105,7 @@ async def test_generate_stream_async(bloom_url, hf_headers):
     assert len(responses) == 1
     response = responses[0]

-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
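For reference, a minimal sketch of driving the async streaming client outside of pytest; the URL is a placeholder and assumes a locally running text-generation-inference server:

import asyncio

from text_generation import AsyncClient


async def main():
    client = AsyncClient("http://127.0.0.1:8080")  # placeholder endpoint
    async for response in client.generate_stream("test", max_new_tokens=1):
        # Each streamed response carries the new token; the final one also
        # carries the full generated_text and details.
        print(response.token.text, response.generated_text)


asyncio.run(main())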

View File

@@ -14,8 +14,8 @@ def test_get_supported_models():
     assert isinstance(get_supported_models(), list)

-def test_client(bloom_model):
-    client = InferenceAPIClient(bloom_model)
+def test_client(flan_t5_xxl):
+    client = InferenceAPIClient(flan_t5_xxl)
     assert isinstance(client, Client)
@@ -24,8 +24,8 @@ def test_client_unsupported_model(unsupported_model):
         InferenceAPIClient(unsupported_model)

-def test_async_client(bloom_model):
-    client = InferenceAPIAsyncClient(bloom_model)
+def test_async_client(flan_t5_xxl):
+    client = InferenceAPIAsyncClient(flan_t5_xxl)
     assert isinstance(client, AsyncClient)
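A minimal usage sketch of the Inference API client against the same checkpoint the tests now use (a valid Hugging Face API token may be required in practice):

from text_generation import InferenceAPIClient

client = InferenceAPIClient("google/flan-t5-xxl")
print(client.generate("Why is the sky blue?", max_new_tokens=20).generated_text)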

View File

@@ -105,10 +105,10 @@ class Request(BaseModel):
     def valid_best_of_stream(cls, field_value, values):
         parameters = values["parameters"]
         if (
             parameters is not None
             and parameters.best_of is not None
             and parameters.best_of > 1
             and field_value
         ):
             raise ValidationError(
                 "`best_of` != 1 is not supported when `stream` == True"
@@ -150,6 +150,22 @@ class FinishReason(Enum):
     StopSequence = "stop_sequence"

+# Additional sequences when using the `best_of` parameter
+class BestOfSequence(BaseModel):
+    # Generated text
+    generated_text: str
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+
 # `generate` details
 class Details(BaseModel):
     # Generation finish reason
@@ -162,6 +178,8 @@ class Details(BaseModel):
     prefill: List[PrefillToken]
     # Generated tokens
     tokens: List[Token]
+    # Additional sequences when using the `best_of` parameter
+    best_of_sequences: Optional[List[BestOfSequence]]

 # `generate` return value
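A minimal sketch of parsing a server payload into the updated model; field names follow the hunk above, and `best_of_sequences` stays None unless `best_of` > 1 was requested:

from text_generation.types import Details

payload = {
    "finish_reason": "length",
    "generated_tokens": 1,
    "seed": None,
    "prefill": [{"id": 0, "text": "<pad>", "logprob": None}],
    "tokens": [{"id": 3, "text": " ", "logprob": -1.984375, "special": False}],
    # no "best_of_sequences" key: the optional field defaults to None
}

details = Details(**payload)
assert details.best_of_sequences is None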