https://github.com/huggingface/text-generation-inference.git
commit c5a0b65c47 (parent 5e1473f0f8)

use flan-t5 for tests
@@ -133,6 +133,22 @@ class FinishReason(Enum):
     StopSequence = "stop_sequence"


+# Additional sequences when using the `best_of` parameter
+class BestOfSequence:
+    # Generated text
+    generated_text: str
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+
+
 # `generate` details
 class Details:
     # Generation finish reason
@@ -145,6 +161,8 @@ class Details:
     prefill: List[PrefillToken]
     # Generated tokens
     tokens: List[Token]
+    # Additional sequences when using the `best_of` parameter
+    best_of_sequences: Optional[List[BestOfSequence]]


 # `generate` return value
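The two hunks above document the new `BestOfSequence` type and the `best_of_sequences` field on `Details`. A minimal usage sketch, assuming `generate` forwards `best_of` and `do_sample` keywords to the request parameters and that a text-generation-inference server is reachable at the placeholder URL below (not verified against a live server):

from text_generation import Client

client = Client("http://127.0.0.1:8080")  # placeholder endpoint

# Assumption: `best_of` is passed through as a generation parameter; sampling is
# enabled because best-of re-ranking draws several sequences, and per the
# validator later in this diff it cannot be combined with streaming.
response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True)

print(response.generated_text)
for seq in response.details.best_of_sequences or []:
    print(seq.generated_text, seq.finish_reason, seq.generated_tokens)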
@@ -4,11 +4,6 @@ from text_generation import __version__
 from huggingface_hub.utils import build_hf_headers


-@pytest.fixture
-def bloom_model():
-    return "bigscience/bloom"
-
-
 @pytest.fixture
 def flan_t5_xxl():
     return "google/flan-t5-xxl"
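The client tests below also use a `flan_t5_xxl_url` fixture that this hunk does not define. A hypothetical sketch of how such a fixture could be derived from the model id, assuming the hosted Inference API URL layout (the real conftest may build it differently):

import pytest


@pytest.fixture
def flan_t5_xxl_url(flan_t5_xxl):
    # Hypothetical helper: point the client at the hosted Inference API.
    return f"https://api-inference.huggingface.co/models/{flan_t5_xxl}"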
@@ -5,21 +5,21 @@ from text_generation.errors import NotFoundError, ValidationError
 from text_generation.types import FinishReason, PrefillToken, Token


-def test_generate(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     response = client.generate("test", max_new_tokens=1)

-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
     assert response.details.prefill[0] == PrefillToken(
-        id=9234, text="test", logprob=None
+        id=0, text="<pad>", logprob=None
     )
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0] == Token(
-        id=17, text=".", logprob=-1.75, special=False
+        id=3, text=" ", logprob=-1.984375, special=False
     )

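The expected values change because google/flan-t5-xxl is an encoder-decoder model: its decoder is primed with a single <pad> token (id 0), which is what the prefill now contains, and greedy decoding of "test" yields a whitespace token rather than BLOOM's ".". A small sketch, assuming a reachable endpoint, that prints the same `Details` fields the assertions check:

from text_generation import Client

client = Client("http://127.0.0.1:8080")  # placeholder endpoint
response = client.generate("test", max_new_tokens=1)

details = response.details
print(details.finish_reason, details.generated_tokens, details.seed)
print([(t.id, t.text) for t in details.prefill])  # prompt-side tokens
print([(t.id, t.text, t.logprob, t.special) for t in details.tokens])  # generated tokens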
@@ -35,8 +35,8 @@ def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
         client.generate("test", max_new_tokens=10_000)


-def test_generate_stream(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate_stream(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     responses = [
         response for response in client.generate_stream("test", max_new_tokens=1)
     ]
@@ -44,7 +44,7 @@ def test_generate_stream(bloom_url, hf_headers):
     assert len(responses) == 1
     response = responses[0]

-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
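For reference, a minimal streaming sketch that relies only on the attributes exercised above; the final streamed response carries `generated_text` and `details` (placeholder endpoint, unverified against a live server):

from text_generation import Client

client = Client("http://127.0.0.1:8080")  # placeholder endpoint

last = None
for last in client.generate_stream("test", max_new_tokens=5):
    pass  # each iteration yields one streamed response

print(last.generated_text)
print(last.details.finish_reason, last.details.generated_tokens)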
@@ -63,21 +63,21 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):


 @pytest.mark.asyncio
-async def test_generate_async(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_async(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     response = await client.generate("test", max_new_tokens=1)

-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
     assert response.details.prefill[0] == PrefillToken(
-        id=9234, text="test", logprob=None
+        id=0, text="<pad>", logprob=None
     )
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0] == Token(
-        id=17, text=".", logprob=-1.75, special=False
+        id=3, text=" ", logprob=-1.984375, special=False
     )

@@ -96,8 +96,8 @@ async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):


 @pytest.mark.asyncio
-async def test_generate_stream_async(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     responses = [
         response async for response in client.generate_stream("test", max_new_tokens=1)
     ]
@@ -105,7 +105,7 @@ async def test_generate_stream_async(bloom_url, hf_headers):
     assert len(responses) == 1
     response = responses[0]

-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
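An equivalent async sketch with `AsyncClient`, again against a placeholder endpoint:

import asyncio

from text_generation import AsyncClient


async def main():
    client = AsyncClient("http://127.0.0.1:8080")  # placeholder endpoint

    # One-shot generation.
    response = await client.generate("test", max_new_tokens=1)
    print(response.generated_text)

    # Streaming: keep only the final response, which carries the details.
    last = None
    async for last in client.generate_stream("test", max_new_tokens=5):
        pass
    print(last.generated_text, last.details.finish_reason)


asyncio.run(main())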
@@ -14,8 +14,8 @@ def test_get_supported_models():
     assert isinstance(get_supported_models(), list)


-def test_client(bloom_model):
-    client = InferenceAPIClient(bloom_model)
+def test_client(flan_t5_xxl):
+    client = InferenceAPIClient(flan_t5_xxl)
     assert isinstance(client, Client)

@@ -24,8 +24,8 @@ def test_client_unsupported_model(unsupported_model):
         InferenceAPIClient(unsupported_model)


-def test_async_client(bloom_model):
-    client = InferenceAPIAsyncClient(bloom_model)
+def test_async_client(flan_t5_xxl):
+    client = InferenceAPIAsyncClient(flan_t5_xxl)
     assert isinstance(client, AsyncClient)

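The Inference API wrappers used above resolve a model id to a hosted endpoint and return regular `Client`/`AsyncClient` instances, which is why the fixtures only need to provide model ids. A minimal sketch, assuming the wrappers are importable from the package root and that the hosted model is reachable (network access and possibly an HF token required):

from text_generation import InferenceAPIClient

client = InferenceAPIClient("google/flan-t5-xxl")
response = client.generate("Translate to German: hello", max_new_tokens=5)
print(response.generated_text)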
@@ -105,10 +105,10 @@ class Request(BaseModel):
     def valid_best_of_stream(cls, field_value, values):
         parameters = values["parameters"]
         if (
             parameters is not None
             and parameters.best_of is not None
             and parameters.best_of > 1
             and field_value
         ):
             raise ValidationError(
                 "`best_of` != 1 is not supported when `stream` == True"
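This validator is why `best_of` never appears together with streaming in the tests above. A hedged sketch of the expected failure mode, assuming `generate_stream` forwards a `best_of` keyword into the request parameters and that the validator's error surfaces as the library's `ValidationError`:

import pytest

from text_generation import Client
from text_generation.errors import ValidationError

client = Client("http://127.0.0.1:8080")  # placeholder endpoint

# Assumption: the request model is validated client-side, so this fails
# before anything is sent to the server.
with pytest.raises(ValidationError):
    for _ in client.generate_stream("test", max_new_tokens=1, best_of=2):
        pass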
@@ -150,6 +150,22 @@ class FinishReason(Enum):
     StopSequence = "stop_sequence"


+# Additional sequences when using the `best_of` parameter
+class BestOfSequence(BaseModel):
+    # Generated text
+    generated_text: str
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+
+
 # `generate` details
 class Details(BaseModel):
     # Generation finish reason
@@ -162,6 +178,8 @@ class Details(BaseModel):
     prefill: List[PrefillToken]
     # Generated tokens
     tokens: List[Token]
+    # Additional sequences when using the `best_of` parameter
+    best_of_sequences: Optional[List[BestOfSequence]]


 # `generate` return value
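Since every field of `BestOfSequence` is visible above, a sequence can be constructed directly alongside the other pydantic types; a minimal sketch, assuming these names import from `text_generation.types`:

from text_generation.types import BestOfSequence, FinishReason, PrefillToken, Token

seq = BestOfSequence(
    generated_text="",
    finish_reason=FinishReason.Length,
    generated_tokens=1,
    seed=None,
    prefill=[PrefillToken(id=0, text="<pad>", logprob=None)],
    tokens=[Token(id=3, text=" ", logprob=-1.984375, special=False)],
)
print(seq.generated_tokens, seq.finish_reason)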