diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 895a351b..6eb11638 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "text-generation" -version = "0.3.2" +version = "0.1.0" description = "Hugging Face Text Generation Python Client" license = "Apache-2.0" authors = ["Olivier Dehaene "] diff --git a/clients/python/tests/conftest.py b/clients/python/tests/conftest.py index 7aa296cf..4298623e 100644 --- a/clients/python/tests/conftest.py +++ b/clients/python/tests/conftest.py @@ -9,6 +9,11 @@ def bloom_model(): return "bigscience/bloom" +@pytest.fixture +def flan_t5_xxl(): + return "google/flan-t5-xxl" + + @pytest.fixture def fake_model(): return "fake/model" @@ -29,6 +34,11 @@ def bloom_url(base_url, bloom_model): return f"{base_url}/{bloom_model}" +@pytest.fixture +def flan_t5_xxl_url(base_url, flan_t5_xxl): + return f"{base_url}/{flan_t5_xxl}" + + @pytest.fixture def fake_url(base_url, fake_model): return f"{base_url}/{fake_model}" diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index e9a6684d..2f96aa87 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -29,8 +29,8 @@ def test_generate_not_found(fake_url, hf_headers): client.generate("test") -def test_generate_validation_error(bloom_url, hf_headers): - client = Client(bloom_url, hf_headers) +def test_generate_validation_error(flan_t5_xxl_url, hf_headers): + client = Client(flan_t5_xxl_url, hf_headers) with pytest.raises(ValidationError): client.generate("test", max_new_tokens=10_000) @@ -56,8 +56,8 @@ def test_generate_stream_not_found(fake_url, hf_headers): list(client.generate_stream("test")) -def test_generate_stream_validation_error(bloom_url, hf_headers): - client = Client(bloom_url, hf_headers) +def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers): + client = Client(flan_t5_xxl_url, hf_headers) with pytest.raises(ValidationError): list(client.generate_stream("test", max_new_tokens=10_000)) @@ -89,8 +89,8 @@ async def test_generate_async_not_found(fake_url, hf_headers): @pytest.mark.asyncio -async def test_generate_async_validation_error(bloom_url, hf_headers): - client = AsyncClient(bloom_url, hf_headers) +async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers): + client = AsyncClient(flan_t5_xxl_url, hf_headers) with pytest.raises(ValidationError): await client.generate("test", max_new_tokens=10_000) @@ -120,8 +120,8 @@ async def test_generate_stream_async_not_found(fake_url, hf_headers): @pytest.mark.asyncio -async def test_generate_stream_async_validation_error(bloom_url, hf_headers): - client = AsyncClient(bloom_url, hf_headers) +async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers): + client = AsyncClient(flan_t5_xxl_url, hf_headers) with pytest.raises(ValidationError): async for _ in client.generate_stream("test", max_new_tokens=10_000): pass diff --git a/clients/python/tests/test_api_inference.py b/clients/python/tests/test_inference_api.py similarity index 92% rename from clients/python/tests/test_api_inference.py rename to clients/python/tests/test_inference_api.py index 67a5dbea..dc744940 100644 --- a/clients/python/tests/test_api_inference.py +++ b/clients/python/tests/test_inference_api.py @@ -7,7 +7,7 @@ from text_generation import ( AsyncClient, ) from text_generation.errors import NotSupportedError -from text_generation.api_inference import get_supported_models +from text_generation.inference_api import get_supported_models def test_get_supported_models(): diff --git a/clients/python/text_generation/__init__.py b/clients/python/text_generation/__init__.py index 0ef7d9fc..88861b37 100644 --- a/clients/python/text_generation/__init__.py +++ b/clients/python/text_generation/__init__.py @@ -15,4 +15,4 @@ __version__ = "0.3.2" from text_generation.client import Client, AsyncClient -from text_generation.api_inference import InferenceAPIClient, InferenceAPIAsyncClient +from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient diff --git a/clients/python/text_generation/api_inference.py b/clients/python/text_generation/inference_api.py similarity index 100% rename from clients/python/text_generation/api_inference.py rename to clients/python/text_generation/inference_api.py