From 338d30b43bb20074b2aba2adae51dd03ffb1025c Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 26 Apr 2023 00:52:44 +0200
Subject: [PATCH] wip

---
 integration-tests/conftest.py               | 98 +++++++++++++++++++++
 integration-tests/models/test_bloom_560m.py |  6 ++
 integration-tests/requirements.txt          |  3 +
 3 files changed, 107 insertions(+)
 create mode 100644 integration-tests/conftest.py
 create mode 100644 integration-tests/models/test_bloom_560m.py
 create mode 100644 integration-tests/requirements.txt

diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
new file mode 100644
index 00000000..154d1c8f
--- /dev/null
+++ b/integration-tests/conftest.py
@@ -0,0 +1,98 @@
+import pytest
+import subprocess
+import time
+import contextlib
+
+from text_generation import Client
+from typing import Optional
+from requests import ConnectionError
+
+
+@contextlib.contextmanager
+def launcher(model_id: str, num_shard: Optional[int] = None, quantize: bool = False):
+    """Start text-generation-launcher for a model and yield a ready client.
+
+    Args:
+        model_id: Model identifier passed to `--model-id`.
+        num_shard: Value passed to `--num-shard`; the flag is omitted when None.
+        quantize: Whether to pass `--quantize`.
+
+    Yields:
+        A `Client` pointed at the launcher's HTTP port on localhost.
+
+    Raises:
+        RuntimeError: If the launcher exits during startup, or if it still
+            refuses connections after 60 attempts.
+    """
+    port = 9999
+    master_port = 19999
+
+    shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server"
+
+    args = [
+        "text-generation-launcher",
+        "--model-id",
+        model_id,
+        "--port",
+        str(port),
+        "--master-port",
+        str(master_port),
+        "--shard-uds-path",
+        shard_uds_path,
+    ]
+
+    if num_shard is not None:
+        # Popen requires string arguments, so the int must be converted.
+        args.extend(["--num-shard", str(num_shard)])
+    if quantize:
+        args.append("--quantize")
+
+    with subprocess.Popen(
+        args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+    ) as process:
+        try:
+            client = Client(f"http://localhost:{port}")
+
+            for _ in range(60):
+                # NOTE: read1() blocks until the launcher writes some output.
+                launcher_output = process.stdout.read1().decode("utf-8")
+                print(launcher_output)
+
+                exit_code = process.poll()
+                if exit_code is not None:
+                    launcher_error = process.stderr.read1().decode("utf-8")
+                    print(launcher_error)
+                    raise RuntimeError(
+                        f"text-generation-launcher terminated with exit code {exit_code}"
+                    )
+
+                try:
+                    client.generate("test", max_new_tokens=1)
+                    break
+                except ConnectionError:
+                    time.sleep(1)
+            else:
+                # Never connected: fail here rather than yield a dead client.
+                raise RuntimeError(
+                    "text-generation-launcher did not become ready in time"
+                )
+
+            yield client
+        finally:
+            # Terminate before Popen.__exit__ runs: it closes the pipes and
+            # then waits for the child, which would otherwise never return.
+            process.terminate()
+
+
+@pytest.fixture(scope="session")
+def bloom_560m():
+    """Session-scoped client for bloom-560m launched without `--num-shard`."""
+    with launcher("bigscience/bloom-560m") as client:
+        yield client
+
+
+@pytest.fixture(scope="session")
+def bloom_560m_multi():
+    """Session-scoped client for bloom-560m launched with `--num-shard 2`."""
+    with launcher("bigscience/bloom-560m", num_shard=2) as client:
+        yield client
diff --git a/integration-tests/models/test_bloom_560m.py b/integration-tests/models/test_bloom_560m.py
new file mode 100644
index 00000000..bfa0331f
--- /dev/null
+++ b/integration-tests/models/test_bloom_560m.py
@@ -0,0 +1,6 @@
+def test_bloom_560m(bloom_560m, snapshot):
+    response = bloom_560m.generate("Test request")
+    # response_multi = bloom_560m_multi.generate("Test request")
+    # assert response == response_multi == snapshot
+    assert response == snapshot
+
diff --git a/integration-tests/requirements.txt b/integration-tests/requirements.txt
new file mode 100644
index 00000000..cfcea059
--- /dev/null
+++ b/integration-tests/requirements.txt
@@ -0,0 +1,3 @@
+syrupy==0.4.2
+text-generation==0.5.1
+pytest==7.3.1
\ No newline at end of file