OlivierDehaene 2023-03-07 18:13:34 +01:00
parent 6e9e194f33
commit 478d5c1403
6 changed files with 21 additions and 11 deletions

View File

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation"
-version = "0.3.2"
+version = "0.1.0"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]

View File

@@ -9,6 +9,11 @@ def bloom_model():
     return "bigscience/bloom"


+@pytest.fixture
+def flan_t5_xxl():
+    return "google/flan-t5-xxl"
+
+
 @pytest.fixture
 def fake_model():
     return "fake/model"
@@ -29,6 +34,11 @@ def bloom_url(base_url, bloom_model):
     return f"{base_url}/{bloom_model}"


+@pytest.fixture
+def flan_t5_xxl_url(base_url, flan_t5_xxl):
+    return f"{base_url}/{flan_t5_xxl}"
+
+
 @pytest.fixture
 def fake_url(base_url, fake_model):
     return f"{base_url}/{fake_model}"
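The new fixtures compose through pytest's dependency injection: flan_t5_xxl_url requests base_url and flan_t5_xxl by name and joins them. A small illustrative test, not part of this commit, with values taken from the fixtures above:

def test_flan_t5_xxl_fixtures(flan_t5_xxl, flan_t5_xxl_url, base_url):
    # pytest resolves base_url and flan_t5_xxl before building flan_t5_xxl_url.
    assert flan_t5_xxl == "google/flan-t5-xxl"
    assert flan_t5_xxl_url == f"{base_url}/{flan_t5_xxl}"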

View File

@@ -29,8 +29,8 @@ def test_generate_not_found(fake_url, hf_headers):
         client.generate("test")


-def test_generate_validation_error(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     with pytest.raises(ValidationError):
         client.generate("test", max_new_tokens=10_000)

@@ -56,8 +56,8 @@ def test_generate_stream_not_found(fake_url, hf_headers):
         list(client.generate_stream("test"))


-def test_generate_stream_validation_error(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     with pytest.raises(ValidationError):
         list(client.generate_stream("test", max_new_tokens=10_000))

@@ -89,8 +89,8 @@ async def test_generate_async_not_found(fake_url, hf_headers):


 @pytest.mark.asyncio
-async def test_generate_async_validation_error(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     with pytest.raises(ValidationError):
         await client.generate("test", max_new_tokens=10_000)

@@ -120,8 +120,8 @@ async def test_generate_stream_async_not_found(fake_url, hf_headers):


 @pytest.mark.asyncio
-async def test_generate_stream_async_validation_error(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     with pytest.raises(ValidationError):
         async for _ in client.generate_stream("test", max_new_tokens=10_000):
             pass
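For reference, a minimal sketch of the client behaviour these tests exercise. It assumes ValidationError is exposed by text_generation.errors (only NotSupportedError is visible in this diff), that the base_url fixture points at the Hugging Face Inference API, and that hf_headers carries an API token; none of those values are shown in this commit.

import asyncio

from text_generation import AsyncClient, Client
from text_generation.errors import ValidationError  # assumed location; the test file's import is not shown here

# Assumed to mirror the base_url and hf_headers fixtures used by the tests.
flan_t5_xxl_url = "https://api-inference.huggingface.co/models/google/flan-t5-xxl"
hf_headers = {"Authorization": "Bearer <hf_api_token>"}  # placeholder token


def sync_validation_error():
    client = Client(flan_t5_xxl_url, hf_headers)
    try:
        # A max_new_tokens value far above the model limit raises ValidationError,
        # as asserted by test_generate_validation_error above.
        client.generate("test", max_new_tokens=10_000)
    except ValidationError as err:
        print(f"generate rejected: {err}")


async def async_validation_error():
    client = AsyncClient(flan_t5_xxl_url, hf_headers)
    try:
        # Same check for the streaming path, mirroring
        # test_generate_stream_async_validation_error.
        async for _ in client.generate_stream("test", max_new_tokens=10_000):
            pass
    except ValidationError as err:
        print(f"generate_stream rejected: {err}")


if __name__ == "__main__":
    sync_validation_error()
    asyncio.run(async_validation_error())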

View File

@@ -7,7 +7,7 @@ from text_generation import (
     AsyncClient,
 )
 from text_generation.errors import NotSupportedError
-from text_generation.api_inference import get_supported_models
+from text_generation.inference_api import get_supported_models


 def test_get_supported_models():

View File

@@ -15,4 +15,4 @@
 __version__ = "0.3.2"

 from text_generation.client import Client, AsyncClient
-from text_generation.api_inference import InferenceAPIClient, InferenceAPIAsyncClient
+from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
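After this commit the module lives at text_generation.inference_api. A short sketch of the updated import paths (call signatures for the Inference API clients are not shown in this diff):

# New import paths after the api_inference -> inference_api rename.
from text_generation import InferenceAPIClient, InferenceAPIAsyncClient  # re-exported by __init__.py above
from text_generation.inference_api import get_supported_models

# The old module path would now fail:
# from text_generation.api_inference import get_supported_models  # ModuleNotFoundError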