mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 03:14:53 +00:00

commit 478d5c1403
parent 6e9e194f33

    publish
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation"
-version = "0.3.2"
+version = "0.1.0"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
@@ -9,6 +9,11 @@ def bloom_model():
     return "bigscience/bloom"


+@pytest.fixture
+def flan_t5_xxl():
+    return "google/flan-t5-xxl"
+
+
 @pytest.fixture
 def fake_model():
     return "fake/model"
@@ -29,6 +34,11 @@ def bloom_url(base_url, bloom_model):
     return f"{base_url}/{bloom_model}"


+@pytest.fixture
+def flan_t5_xxl_url(base_url, flan_t5_xxl):
+    return f"{base_url}/{flan_t5_xxl}"
+
+
 @pytest.fixture
 def fake_url(base_url, fake_model):
     return f"{base_url}/{fake_model}"
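
The two new fixtures compose with a base_url fixture that this hunk's context implies but does not show. A minimal self-contained sketch of the same composition pattern, with an assumed base_url value for illustration:

import pytest


@pytest.fixture
def base_url():
    # Assumed value for illustration; the real fixture is defined elsewhere in conftest.py
    return "https://api-inference.huggingface.co/models"


@pytest.fixture
def flan_t5_xxl():
    return "google/flan-t5-xxl"


@pytest.fixture
def flan_t5_xxl_url(base_url, flan_t5_xxl):
    # pytest injects the two fixtures above by argument name
    return f"{base_url}/{flan_t5_xxl}"

pytest resolves fixture dependencies by parameter name, which is why the URL fixture can simply declare base_url and the model-id fixture as arguments.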
@@ -29,8 +29,8 @@ def test_generate_not_found(fake_url, hf_headers):
         client.generate("test")


-def test_generate_validation_error(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     with pytest.raises(ValidationError):
         client.generate("test", max_new_tokens=10_000)

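For context, the call pattern this test guards: a Client pointed at a model URL validates max_new_tokens client-side and raises ValidationError past the limit. A minimal sketch of the non-error path; the endpoint URL is an assumption, not taken from this diff:

from text_generation import Client

# Hypothetical endpoint; point this at a real text-generation-inference deployment
client = Client("https://api-inference.huggingface.co/models/google/flan-t5-xxl")

# A small max_new_tokens passes validation; 10_000 (as in the test) would not
response = client.generate("Why is the sky blue?", max_new_tokens=20)
print(response.generated_text)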
@@ -56,8 +56,8 @@ def test_generate_stream_not_found(fake_url, hf_headers):
         list(client.generate_stream("test"))


-def test_generate_stream_validation_error(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     with pytest.raises(ValidationError):
         list(client.generate_stream("test", max_new_tokens=10_000))

@@ -89,8 +89,8 @@ async def test_generate_async_not_found(fake_url, hf_headers):


 @pytest.mark.asyncio
-async def test_generate_async_validation_error(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     with pytest.raises(ValidationError):
         await client.generate("test", max_new_tokens=10_000)

@@ -120,8 +120,8 @@ async def test_generate_stream_async_not_found(fake_url, hf_headers):


 @pytest.mark.asyncio
-async def test_generate_stream_async_validation_error(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     with pytest.raises(ValidationError):
         async for _ in client.generate_stream("test", max_new_tokens=10_000):
             pass
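
The async test iterates the stream inside pytest.raises because the client only triggers validation once the generator is consumed. A sketch of normal streaming consumption; the URL is assumed, and the token attributes follow the client's streaming response shape:

import asyncio

from text_generation import AsyncClient


async def main():
    # Hypothetical endpoint; point this at a real text-generation-inference deployment
    client = AsyncClient("https://api-inference.huggingface.co/models/google/flan-t5-xxl")
    # Each streamed response carries a single generated token
    async for response in client.generate_stream("Why is the sky blue?", max_new_tokens=20):
        if not response.token.special:
            print(response.token.text, end="", flush=True)


asyncio.run(main())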
@@ -7,7 +7,7 @@ from text_generation import (
     AsyncClient,
 )
 from text_generation.errors import NotSupportedError
-from text_generation.api_inference import get_supported_models
+from text_generation.inference_api import get_supported_models


 def test_get_supported_models():
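
The renamed module keeps the same helper. A sketch of how get_supported_models pairs with NotSupportedError, based only on what these tests import; the behavior described in the comments is an assumption:

from text_generation import InferenceAPIClient
from text_generation.errors import NotSupportedError
from text_generation.inference_api import get_supported_models

# Presumably returns the ids of models served with text-generation-inference
supported = get_supported_models()
print(supported[:5] if supported else "no model list available")

try:
    # A model outside the supported list should be rejected up front
    client = InferenceAPIClient("fake/model")
except NotSupportedError as err:
    print(err)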
@@ -15,4 +15,4 @@
 __version__ = "0.3.2"

 from text_generation.client import Client, AsyncClient
-from text_generation.api_inference import InferenceAPIClient, InferenceAPIAsyncClient
+from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
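
Because __init__.py re-exports these names, the api_inference → inference_api rename stays internal and downstream imports are unaffected:

# Public imports keep working exactly as before the rename
from text_generation import Client, AsyncClient, InferenceAPIClient, InferenceAPIAsyncClient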