diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index 154d1c8f..8caa7158 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -1,74 +1,106 @@
-import pytest
 import subprocess
 import time
 import contextlib
+import pytest
+import asyncio
+import os
 
-from text_generation import Client
-from typing import Optional
-from requests import ConnectionError
+from typing import Optional, List
+from aiohttp import ClientConnectorError
+
+from text_generation import AsyncClient
+from text_generation.types import Response
+
+DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
 
 
-@contextlib.contextmanager
-def launcher(model_id: str, num_shard: Optional[int] = None, quantize: bool = False):
-    port = 9999
-    master_port = 19999
+@pytest.fixture(scope="module")
+def event_loop():
+    loop = asyncio.get_event_loop()
+    yield loop
+    loop.close()
 
-    shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server"
 
-    args = [
-        "text-generation-launcher",
-        "--model-id",
-        model_id,
-        "--port",
-        str(port),
-        "--master-port",
-        str(master_port),
-        "--shard-uds-path",
-        shard_uds_path,
-    ]
+@pytest.fixture(scope="module")
+def launcher(event_loop):
+    @contextlib.asynccontextmanager
+    async def local_launcher_inner(model_id: str, num_shard: Optional[int] = None, quantize: bool = False):
+        port = 9999
+        master_port = 19999
 
-    if num_shard is not None:
-        args.extend(["--num-shard", num_shard])
-    if quantize:
-        args.append("--quantize")
+        shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server"
 
-    with subprocess.Popen(
-        args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
-    ) as process:
-        client = Client(f"http://localhost:{port}")
+        args = [
+            "text-generation-launcher",
+            "--model-id",
+            model_id,
+            "--port",
+            str(port),
+            "--master-port",
+            str(master_port),
+            "--shard-uds-path",
+            shard_uds_path,
+        ]
+
+        if num_shard is not None:
+            args.extend(["--num-shard", str(num_shard)])
+        if quantize:
+            args.append("--quantize")
+
+        with subprocess.Popen(
+            args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+        ) as process:
+            client = AsyncClient(f"http://localhost:{port}")
+
+            healthy = False
+
+            for _ in range(60):
+                launcher_output = process.stdout.read1().decode("utf-8")
+                print(launcher_output)
+
+                exit_code = process.poll()
+                if exit_code is not None:
+                    launcher_error = process.stderr.read1().decode("utf-8")
+                    print(launcher_error)
+                    raise RuntimeError(
+                        f"text-generation-launcher terminated with exit code {exit_code}"
+                    )
+
+                try:
+                    await client.generate("test", max_new_tokens=1)
+                    healthy = True
+                    break
+                except ClientConnectorError:
+                    time.sleep(1)
+
+            if healthy:
+                yield client
+
+            process.terminate()
+
+            for _ in range(60):
+                exit_code = process.wait(1)
+                if exit_code is not None:
+                    break
 
-        for _ in range(60):
             launcher_output = process.stdout.read1().decode("utf-8")
             print(launcher_output)
 
-            exit_code = process.poll()
-            if exit_code is not None:
-                launcher_error = process.stderr.read1().decode("utf-8")
-                print(launcher_error)
-                raise RuntimeError(
-                    f"text-generation-launcher terminated with exit code {exit_code}"
-                )
+            process.stdout.close()
+            process.stderr.close()
 
-            try:
-                client.generate("test", max_new_tokens=1)
-                break
-            except ConnectionError:
-                time.sleep(1)
+            if not healthy:
+                raise RuntimeError(f"unable to start model {model_id} with command: {' '.join(args)}")
 
-        yield client
-
-        process.stdout.close()
-        process.stderr.close()
-        process.terminate()
+    return local_launcher_inner
 
 
-@pytest.fixture(scope="session")
-def bloom_560m():
-    with launcher("bigscience/bloom-560m") as client:
-        yield client
+@pytest.fixture(scope="module")
+def generate_load():
+    async def generate_load_inner(client: AsyncClient, prompt: str, max_new_tokens: int, n: int) -> List[Response]:
+        futures = [client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)]
+        results = await asyncio.gather(*futures)
+        return results
 
-@pytest.fixture(scope="session")
-def bloom_560m_multi():
-    with launcher("bigscience/bloom-560m", num_shard=2) as client:
-        yield client
+    return generate_load_inner
diff --git a/integration-tests/models/test_bloom_560m.py b/integration-tests/models/test_bloom_560m.py
index bfa0331f..d44cefc3 100644
--- a/integration-tests/models/test_bloom_560m.py
+++ b/integration-tests/models/test_bloom_560m.py
@@ -1,6 +1,25 @@
-def test_bloom_560m(bloom_560m, snapshot):
-    response = bloom_560m.generate("Test request")
-    # response_multi = bloom_560m_multi.generate("Test request")
-    # assert response == response_multi == snapshot
+import pytest
+
+
+@pytest.fixture(scope="module")
+async def bloom_560m(launcher):
+    async with launcher("bigscience/bloom-560m") as client:
+        yield client
+
+
+@pytest.mark.asyncio
+async def test_bloom_560m(bloom_560m, snapshot):
+    response = await bloom_560m.generate("Test request", max_new_tokens=10)
+
+    assert response.details.generated_tokens == 10
     assert response == snapshot
+
+
+@pytest.mark.asyncio
+async def test_bloom_560m_load(bloom_560m, generate_load, snapshot):
+    responses = await generate_load(bloom_560m, "Test request", max_new_tokens=10, n=4)
+
+    assert len(responses) == 4
+
+    assert responses == snapshot
diff --git a/integration-tests/models/test_bloom_560m_sharded.py b/integration-tests/models/test_bloom_560m_sharded.py
new file mode 100644
index 00000000..d6febf24
--- /dev/null
+++ b/integration-tests/models/test_bloom_560m_sharded.py
@@ -0,0 +1,25 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+async def bloom_560m_sharded(launcher):
+    async with launcher("bigscience/bloom-560m", num_shard=2) as client:
+        yield client
+
+
+@pytest.mark.asyncio
+async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot):
+    response = await bloom_560m_sharded.generate("Test request", max_new_tokens=10)
+
+    assert response.details.generated_tokens == 10
+    assert response == snapshot
+
+
+@pytest.mark.asyncio
+async def test_bloom_560m_sharded_load(bloom_560m_sharded, generate_load, snapshot):
+    responses = await generate_load(bloom_560m_sharded, "Test request", max_new_tokens=10, n=4)
+
+    assert len(responses) == 4
+
+    assert responses == snapshot