Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-09 11:24:53 +00:00

Commit c5a0b65c47: use flan-t5 for tests
Parent commit: 5e1473f0f8
@@ -133,6 +133,22 @@ class FinishReason(Enum):
     StopSequence = "stop_sequence"
 
 
+# Additional sequences when using the `best_of` parameter
+class BestOfSequence:
+    # Generated text
+    generated_text: str
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+
+
 # `generate` details
 class Details:
     # Generation finish reason
@@ -145,6 +161,8 @@ class Details:
     prefill: List[PrefillToken]
     # Generated tokens
     tokens: List[Token]
+    # Additional sequences when using the `best_of` parameter
+    best_of_sequences: Optional[List[BestOfSequence]]
 
 
 # `generate` return value
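The new `BestOfSequence` entries surface the extra candidates produced when `best_of` is set. A minimal, hedged usage sketch; the local server URL is a placeholder, and the exact `best_of`/`do_sample` keyword arguments of `Client.generate` are assumed from this diff rather than confirmed by it:

from text_generation import Client

client = Client("http://127.0.0.1:8080")  # placeholder URL for a running TGI server
response = client.generate(
    "What is Deep Learning?",
    max_new_tokens=10,
    best_of=2,       # request two candidate sequences (assumed parameter name)
    do_sample=True,  # best_of > 1 generally requires sampling
)

print(response.generated_text)
# The extra candidates land in the new `best_of_sequences` field.
if response.details.best_of_sequences:
    for seq in response.details.best_of_sequences:
        print(seq.finish_reason, seq.generated_tokens, repr(seq.generated_text))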
@@ -4,11 +4,6 @@ from text_generation import __version__
 from huggingface_hub.utils import build_hf_headers
 
 
-@pytest.fixture
-def bloom_model():
-    return "bigscience/bloom"
-
-
 @pytest.fixture
 def flan_t5_xxl():
     return "google/flan-t5-xxl"
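The tests below consume a `flan_t5_xxl_url` fixture that is outside this excerpt. A rough sketch of what such a fixture could look like; the endpoint pattern is an assumption, not something shown in the diff:

import pytest


@pytest.fixture
def flan_t5_xxl_url(flan_t5_xxl):
    # Hypothetical: point the tests at the hosted endpoint for the model id.
    return f"https://api-inference.huggingface.co/models/{flan_t5_xxl}"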
@@ -5,21 +5,21 @@ from text_generation.errors import NotFoundError, ValidationError
 from text_generation.types import FinishReason, PrefillToken, Token
 
 
-def test_generate(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     response = client.generate("test", max_new_tokens=1)
 
-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
     assert response.details.prefill[0] == PrefillToken(
-        id=9234, text="test", logprob=None
+        id=0, text="<pad>", logprob=None
     )
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0] == Token(
-        id=17, text=".", logprob=-1.75, special=False
+        id=3, text=" ", logprob=-1.984375, special=False
     )
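The expected ids change because flan-t5 uses a T5 SentencePiece vocabulary: id 0 is the `<pad>` decoder start token and id 3 is the whitespace piece, hence the empty `generated_text`. A quick sanity check outside the test suite, assuming `transformers` is installed:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/flan-t5-xxl")
# Decode the ids the updated assertions expect; should print ['<pad>', '▁'].
print(tok.convert_ids_to_tokens([0, 3]))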
@@ -35,8 +35,8 @@ def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
     client.generate("test", max_new_tokens=10_000)
 
 
-def test_generate_stream(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate_stream(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     responses = [
         response for response in client.generate_stream("test", max_new_tokens=1)
     ]
@@ -44,7 +44,7 @@ def test_generate_stream(bloom_url, hf_headers):
     assert len(responses) == 1
     response = responses[0]
 
-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
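For context, `generate_stream` yields one response per generated token, and only the final message carries `generated_text` and `details`, which is what the assertions above rely on. A hedged consumption sketch with a placeholder server URL:

from text_generation import Client

client = Client("http://127.0.0.1:8080")  # placeholder URL for a running TGI server

text = ""
for response in client.generate_stream("What is Deep Learning?", max_new_tokens=20):
    if not response.token.special:
        text += response.token.text
    if response.details is not None:  # only populated on the last message
        print(response.details.finish_reason, response.details.generated_tokens)
print(text)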
@@ -63,21 +63,21 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
 
 
 @pytest.mark.asyncio
-async def test_generate_async(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_async(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     response = await client.generate("test", max_new_tokens=1)
 
-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
     assert response.details.prefill[0] == PrefillToken(
-        id=9234, text="test", logprob=None
+        id=0, text="<pad>", logprob=None
     )
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0] == Token(
-        id=17, text=".", logprob=-1.75, special=False
+        id=3, text=" ", logprob=-1.984375, special=False
     )
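The async tests mirror the sync ones through `AsyncClient`. A minimal sketch of driving it outside pytest, again with a placeholder URL:

import asyncio

from text_generation import AsyncClient


async def main():
    client = AsyncClient("http://127.0.0.1:8080")  # placeholder URL
    response = await client.generate("test", max_new_tokens=1)
    print(repr(response.generated_text), response.details.finish_reason)


asyncio.run(main())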
@@ -96,8 +96,8 @@ async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
 
 
 @pytest.mark.asyncio
-async def test_generate_stream_async(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     responses = [
         response async for response in client.generate_stream("test", max_new_tokens=1)
     ]
@@ -105,7 +105,7 @@ async def test_generate_stream_async(bloom_url, hf_headers):
     assert len(responses) == 1
     response = responses[0]
 
-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
@@ -14,8 +14,8 @@ def test_get_supported_models():
     assert isinstance(get_supported_models(), list)
 
 
-def test_client(bloom_model):
-    client = InferenceAPIClient(bloom_model)
+def test_client(flan_t5_xxl):
+    client = InferenceAPIClient(flan_t5_xxl)
     assert isinstance(client, Client)
 
 
@@ -24,8 +24,8 @@ def test_client_unsupported_model(unsupported_model):
     InferenceAPIClient(unsupported_model)
 
 
-def test_async_client(bloom_model):
-    client = InferenceAPIAsyncClient(bloom_model)
+def test_async_client(flan_t5_xxl):
+    client = InferenceAPIAsyncClient(flan_t5_xxl)
     assert isinstance(client, AsyncClient)
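`InferenceAPIClient` and `InferenceAPIAsyncClient` resolve a model id on the hosted Inference API instead of taking a URL. A short usage sketch, assuming a valid Hugging Face token is available in the environment:

from text_generation import InferenceAPIClient

client = InferenceAPIClient("google/flan-t5-xxl")
response = client.generate("Translate to German: hello", max_new_tokens=16)
print(response.generated_text)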
@@ -105,10 +105,10 @@ class Request(BaseModel):
     def valid_best_of_stream(cls, field_value, values):
         parameters = values["parameters"]
         if (
-            parameters is not None
-            and parameters.best_of is not None
-            and parameters.best_of > 1
-            and field_value
+            parameters is not None
+            and parameters.best_of is not None
+            and parameters.best_of > 1
+            and field_value
         ):
             raise ValidationError(
                 "`best_of` != 1 is not supported when `stream` == True"
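This validator enforces a cross-field constraint: streaming and `best_of > 1` cannot be combined. A standalone sketch of the same pattern, not the library's actual `Request` model; field names other than `parameters`, `best_of`, and `stream` are assumptions, and it raises a plain `ValueError` rather than the library's `ValidationError`:

from typing import Optional

from pydantic import BaseModel, validator  # pydantic v1 style, matching the `values[...]` access above


class Parameters(BaseModel):
    best_of: Optional[int] = None


class Request(BaseModel):
    inputs: str
    parameters: Optional[Parameters] = None  # declared before `stream`, so it is present in `values`
    stream: bool = False

    @validator("stream")
    def valid_best_of_stream(cls, field_value, values):
        parameters = values["parameters"]
        if (
            parameters is not None
            and parameters.best_of is not None
            and parameters.best_of > 1
            and field_value
        ):
            raise ValueError("`best_of` != 1 is not supported when `stream` == True")
        return field_value


# Raises a pydantic validation error:
# Request(inputs="test", parameters=Parameters(best_of=2), stream=True)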
@@ -150,6 +150,22 @@ class FinishReason(Enum):
     StopSequence = "stop_sequence"
 
 
+# Additional sequences when using the `best_of` parameter
+class BestOfSequence(BaseModel):
+    # Generated text
+    generated_text: str
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+
+
 # `generate` details
 class Details(BaseModel):
     # Generation finish reason
@@ -162,6 +178,8 @@ class Details(BaseModel):
     prefill: List[PrefillToken]
     # Generated tokens
     tokens: List[Token]
+    # Additional sequences when using the `best_of` parameter
+    best_of_sequences: Optional[List[BestOfSequence]]
 
 
 # `generate` return value
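A quick construction sketch for the new pydantic model, using illustrative field values taken from the updated tests; it assumes `BestOfSequence` is importable from `text_generation.types` like the other types used above:

from text_generation.types import BestOfSequence, FinishReason, PrefillToken, Token

seq = BestOfSequence(
    generated_text="",
    finish_reason=FinishReason.Length,
    generated_tokens=1,
    seed=None,
    prefill=[PrefillToken(id=0, text="<pad>", logprob=None)],
    tokens=[Token(id=3, text=" ", logprob=-1.984375, special=False)],
)
print(seq.finish_reason, seq.generated_tokens)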