Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-09 11:24:53 +00:00

Commit 478d5c1403 ("publish")
Parent 6e9e194f33
clients/python/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation"
-version = "0.3.2"
+version = "0.1.0"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
clients/python/tests/conftest.py
@@ -9,6 +9,11 @@ def bloom_model():
     return "bigscience/bloom"


+@pytest.fixture
+def flan_t5_xxl():
+    return "google/flan-t5-xxl"
+
+
 @pytest.fixture
 def fake_model():
     return "fake/model"
@@ -29,6 +34,11 @@ def bloom_url(base_url, bloom_model):
     return f"{base_url}/{bloom_model}"


+@pytest.fixture
+def flan_t5_xxl_url(base_url, flan_t5_xxl):
+    return f"{base_url}/{flan_t5_xxl}"
+
+
 @pytest.fixture
 def fake_url(base_url, fake_model):
     return f"{base_url}/{fake_model}"
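These fixtures compose: flan_t5_xxl_url requests base_url and the new flan_t5_xxl fixture and joins them, so every test gets the full endpoint URL from one argument. A minimal self-contained sketch of the pattern; the base_url value is an assumption here, since the real fixture is defined earlier in conftest.py and is not shown in this diff:

import pytest

@pytest.fixture
def base_url():
    # Assumed value; the real fixture lives earlier in conftest.py.
    return "https://api-inference.huggingface.co/models"

@pytest.fixture
def flan_t5_xxl():
    return "google/flan-t5-xxl"

@pytest.fixture
def flan_t5_xxl_url(base_url, flan_t5_xxl):
    # pytest resolves both fixture arguments before running this one.
    return f"{base_url}/{flan_t5_xxl}"

def test_url_is_composed(flan_t5_xxl_url):
    assert flan_t5_xxl_url.endswith("/google/flan-t5-xxl")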
clients/python/tests/test_client.py
@@ -29,8 +29,8 @@ def test_generate_not_found(fake_url, hf_headers):
         client.generate("test")


-def test_generate_validation_error(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     with pytest.raises(ValidationError):
         client.generate("test", max_new_tokens=10_000)

@@ -56,8 +56,8 @@ def test_generate_stream_not_found(fake_url, hf_headers):
         list(client.generate_stream("test"))


-def test_generate_stream_validation_error(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     with pytest.raises(ValidationError):
         list(client.generate_stream("test", max_new_tokens=10_000))

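What these two tests exercise, as a standalone sketch: requesting far more tokens than the endpoint allows makes the server reject the call, and the client surfaces that as ValidationError. This assumes network access, a valid token in the headers, and that ValidationError is exported from text_generation.errors, as the NotSupportedError import later in this diff suggests:

from text_generation import Client
from text_generation.errors import ValidationError

# Hypothetical endpoint URL mirroring the flan_t5_xxl_url fixture.
url = "https://api-inference.huggingface.co/models/google/flan-t5-xxl"
client = Client(url, headers={"Authorization": "Bearer <hf_token>"})

try:
    client.generate("test", max_new_tokens=10_000)
except ValidationError as err:
    # The endpoint refuses max_new_tokens this large; no text is generated.
    print(f"rejected: {err}")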
@@ -89,8 +89,8 @@ async def test_generate_async_not_found(fake_url, hf_headers):


 @pytest.mark.asyncio
-async def test_generate_async_validation_error(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     with pytest.raises(ValidationError):
         await client.generate("test", max_new_tokens=10_000)

@@ -120,8 +120,8 @@ async def test_generate_stream_async_not_found(fake_url, hf_headers):


 @pytest.mark.asyncio
-async def test_generate_stream_async_validation_error(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     with pytest.raises(ValidationError):
         async for _ in client.generate_stream("test", max_new_tokens=10_000):
             pass
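The async counterparts follow the same shape. Note why the streaming test must loop: an async generator does no work until it is consumed, so the error only surfaces while iterating. A sketch under the same assumptions as the synchronous example above:

import asyncio
from text_generation import AsyncClient
from text_generation.errors import ValidationError

async def main():
    # Same hypothetical endpoint as the synchronous sketch above.
    url = "https://api-inference.huggingface.co/models/google/flan-t5-xxl"
    client = AsyncClient(url, headers={"Authorization": "Bearer <hf_token>"})
    try:
        # Iterating the stream triggers the request and hence the error.
        async for _ in client.generate_stream("test", max_new_tokens=10_000):
            pass
    except ValidationError as err:
        print(f"rejected: {err}")

asyncio.run(main())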
clients/python/tests/test_inference_api.py
@@ -7,7 +7,7 @@ from text_generation import (
     AsyncClient,
 )
 from text_generation.errors import NotSupportedError
-from text_generation.api_inference import get_supported_models
+from text_generation.inference_api import get_supported_models


 def test_get_supported_models():
clients/python/text_generation/__init__.py
@@ -15,4 +15,4 @@
 __version__ = "0.3.2"

 from text_generation.client import Client, AsyncClient
-from text_generation.api_inference import InferenceAPIClient, InferenceAPIAsyncClient
+from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
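Taken together, the rename keeps the package's public surface importable, just under the new module path. A quick smoke check using only names visible in the hunks above:

# Re-exports from text_generation/__init__.py still work after the rename.
from text_generation import Client, AsyncClient, InferenceAPIClient, InferenceAPIAsyncClient

# Direct imports must now use the new module name; the old
# text_generation.api_inference path would raise ModuleNotFoundError.
from text_generation.inference_api import get_supported_models

print(get_supported_models())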