diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index 8caa7158..cc64c064 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -4,7 +4,10 @@ import contextlib
 import pytest
 import asyncio
 import os
+import docker
 
+from datetime import datetime
+from docker.errors import NotFound
 from typing import Optional, List
 
 from aiohttp import ClientConnectorError
@@ -23,8 +26,10 @@ def event_loop():
 
 
 @pytest.fixture(scope="module")
 def launcher(event_loop):
-    @contextlib.asynccontextmanager
-    async def local_launcher_inner(model_id: str, num_shard: Optional[int] = None, quantize: bool = False):
+    @contextlib.contextmanager
+    def local_launcher(
+        model_id: str, num_shard: Optional[int] = None, quantize: bool = False
+    ):
         port = 9999
         master_port = 19999
@@ -48,40 +53,12 @@ def launcher(event_loop):
             args.append("--quantize")
 
         with subprocess.Popen(
-                args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+            args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
         ) as process:
-            client = AsyncClient(f"http://localhost:{port}")
-
-            healthy = False
-
-            for _ in range(60):
-                launcher_output = process.stdout.read1().decode("utf-8")
-                print(launcher_output)
-
-                exit_code = process.poll()
-                if exit_code is not None:
-                    launcher_error = process.stderr.read1().decode("utf-8")
-                    print(launcher_error)
-                    raise RuntimeError(
-                        f"text-generation-launcher terminated with exit code {exit_code}"
-                    )
-
-                try:
-                    await client.generate("test", max_new_tokens=1)
-                    healthy = True
-                    break
-                except ClientConnectorError:
-                    time.sleep(1)
-
-            if healthy:
-                yield client
+            yield AsyncClient(f"http://localhost:{port}")
 
             process.terminate()
-
-            for _ in range(60):
-                exit_code = process.wait(1)
-                if exit_code is not None:
-                    break
+            process.wait(60)
 
             launcher_output = process.stdout.read1().decode("utf-8")
             print(launcher_output)
@@ -89,16 +66,65 @@
             process.stdout.close()
             process.stderr.close()
 
-            if not healthy:
-                raise RuntimeError(f"unable to start model {model_id} with command: {' '.join(args)}")
+    @contextlib.contextmanager
+    def docker_launcher(
+        model_id: str, num_shard: Optional[int] = None, quantize: bool = False
+    ):
+        port = 9999
 
-    return launcher_inner
+        args = ["--model-id", model_id, "--env"]
+
+        if num_shard is not None:
+            args.extend(["--num-shard", str(num_shard)])
+        if quantize:
+            args.append("--quantize")
+
+        client = docker.from_env()
+
+        container_name = f"tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}"
+
+        try:
+            container = client.containers.get(container_name)
+            container.stop()
+            container.wait()
+        except NotFound:
+            pass
+
+        gpu_count = num_shard if num_shard is not None else 1
+
+        container = client.containers.run(
+            DOCKER_IMAGE,
+            command=args,
+            name=container_name,
+            auto_remove=True,
+            detach=True,
+            device_requests=[
+                docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
+            ],
+            volumes=["/data:/data"],
+            ports={"80/tcp": port},
+        )
+
+        yield AsyncClient(f"http://localhost:{port}")
+
+        container.stop()
+
+        container_output = container.logs().decode("utf-8")
+        print(container_output)
+
+    if DOCKER_IMAGE is not None:
+        return docker_launcher
+    return local_launcher
 
 
 @pytest.fixture(scope="module")
 def generate_load():
-    async def generate_load_inner(client: AsyncClient, prompt: str, max_new_tokens: int, n: int) -> List[Response]:
-        futures = [client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)]
+    async def generate_load_inner(
+        client: AsyncClient, prompt: str, max_new_tokens: int, n: int
+    ) -> List[Response]:
+        futures = [
+            client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
+        ]
         results = await asyncio.gather(*futures)
         return results
diff --git a/integration-tests/models/__snapshots__/test_bloom_560m.ambr b/integration-tests/models/__snapshots__/test_bloom_560m.ambr
new file mode 100644
index 00000000..d81b225e
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_bloom_560m.ambr
@@ -0,0 +1,94 @@
+# serializer version: 1
+# name: test_bloom_560m
+  dict({
+    'details': dict({
+      'best_of_sequences': None,
+      'finish_reason': <FinishReason.Length: 'length'>,
+      'generated_tokens': 10,
+      'prefill': list([
+        dict({
+          'id': 10264,
+          'logprob': None,
+          'text': 'Test',
+        }),
+        dict({
+          'id': 8821,
+          'logprob': -11.3125,
+          'text': ' request',
+        }),
+      ]),
+      'seed': None,
+      'tokens': list([
+        dict({
+          'id': 11,
+          'logprob': -2.859375,
+          'special': False,
+          'text': '(',
+        }),
+        dict({
+          'id': 5,
+          'logprob': -2.34375,
+          'special': False,
+          'text': '"',
+        }),
+        dict({
+          'id': 1587,
+          'logprob': -3.25,
+          'special': False,
+          'text': 'get',
+        }),
+        dict({
+          'id': 5,
+          'logprob': -1.828125,
+          'special': False,
+          'text': '"',
+        }),
+        dict({
+          'id': 15,
+          'logprob': -0.35546875,
+          'special': False,
+          'text': ',',
+        }),
+        dict({
+          'id': 567,
+          'logprob': -2.4375,
+          'special': False,
+          'text': ' "',
+        }),
+        dict({
+          'id': 17,
+          'logprob': -4.40625,
+          'special': False,
+          'text': '.',
+        }),
+        dict({
+          'id': 5,
+          'logprob': -2.46875,
+          'special': False,
+          'text': '"',
+        }),
+        dict({
+          'id': 12,
+          'logprob': -1.6015625,
+          'special': False,
+          'text': ')',
+        }),
+        dict({
+          'id': 30,
+          'logprob': -0.14550781,
+          'special': False,
+          'text': ';',
+        }),
+      ]),
+    }),
+    'generated_text': '("get", ".");',
+  })
+# ---
+# name: test_bloom_560m_load
+  list([
+    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
+    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
+    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
+    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
+  ])
+# ---
diff --git a/integration-tests/models/__snapshots__/test_bloom_560m_sharded.ambr b/integration-tests/models/__snapshots__/test_bloom_560m_sharded.ambr
new file mode 100644
index 00000000..7cf26255
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_bloom_560m_sharded.ambr
@@ -0,0 +1,94 @@
+# serializer version: 1
+# name: test_bloom_560m_sharded
+  dict({
+    'details': dict({
+      'best_of_sequences': None,
+      'finish_reason': <FinishReason.Length: 'length'>,
+      'generated_tokens': 10,
+      'prefill': list([
+        dict({
+          'id': 10264,
+          'logprob': None,
+          'text': 'Test',
+        }),
+        dict({
+          'id': 8821,
+          'logprob': -11.375,
+          'text': ' request',
+        }),
+      ]),
+      'seed': None,
+      'tokens': list([
+        dict({
+          'id': 11,
+          'logprob': -2.734375,
+          'special': False,
+          'text': '(',
+        }),
+        dict({
+          'id': 5,
+          'logprob': -1.9765625,
+          'special': False,
+          'text': '"',
+        }),
+        dict({
+          'id': 1587,
+          'logprob': -3.140625,
+          'special': False,
+          'text': 'get',
+        }),
+        dict({
+          'id': 5,
+          'logprob': -3.515625,
+          'special': False,
+          'text': '"',
+        }),
+        dict({
+          'id': 15,
+          'logprob': -0.37304688,
+          'special': False,
+          'text': ',',
+        }),
+        dict({
+          'id': 567,
+          'logprob': -2.6875,
+          'special': False,
+          'text': ' "',
+        }),
+        dict({
+          'id': 17,
+          'logprob': -4.65625,
+          'special': False,
+          'text': '.',
+        }),
+        dict({
+          'id': 5,
+          'logprob': -2.28125,
+          'special': False,
+          'text': '"',
+        }),
+        dict({
+          'id': 12,
+          'logprob': -1.6640625,
+          'special': False,
+          'text': ')',
+        }),
+        dict({
+          'id': 30,
+          'logprob': -0.14355469,
+          'special': False,
+          'text': ';',
+        }),
+      ]),
+    }),
+    'generated_text': '("get", ".");',
+  })
+# ---
+# name: test_bloom_560m_sharded_load
+  list([
+    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)),
text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)), + Response(generated_text='("get", ".");', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)), + Response(generated_text='("get", ".");', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)), + Response(generated_text='("get", ".");', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)), + ]) +# --- diff --git a/integration-tests/models/test_bloom_560m.py b/integration-tests/models/test_bloom_560m.py index d44cefc3..f29cdb2d 100644 --- a/integration-tests/models/test_bloom_560m.py +++ b/integration-tests/models/test_bloom_560m.py @@ -1,24 +1,29 @@ import pytest +from utils import health_check + @pytest.fixture(scope="module") -async def bloom_560m(launcher): - async with launcher("bigscience/bloom-560m") as client: +def bloom_560(launcher): + with 
launcher("bigscience/bloom-560m") as client: yield client - @pytest.mark.asyncio -async def test_bloom_560m(bloom_560m, snapshot): - response = await bloom_560m.generate("Test request", max_new_tokens=10) +async def test_bloom_560m(bloom_560, snapshot): + await health_check(bloom_560, 60) + + response = await bloom_560.generate("Test request", max_new_tokens=10) assert response.details.generated_tokens == 10 assert response == snapshot @pytest.mark.asyncio -async def test_bloom_560m_load(bloom_560m, generate_load, snapshot): - responses = await generate_load(bloom_560m, "Test request", max_new_tokens=10, n=4) +async def test_bloom_560m_load(bloom_560, generate_load, snapshot): + await health_check(bloom_560, 60) + + responses = await generate_load(bloom_560, "Test request", max_new_tokens=10, n=4) assert len(responses) == 4 diff --git a/integration-tests/models/test_bloom_560m_sharded.py b/integration-tests/models/test_bloom_560m_sharded.py index d6febf24..e3d1cbe0 100644 --- a/integration-tests/models/test_bloom_560m_sharded.py +++ b/integration-tests/models/test_bloom_560m_sharded.py @@ -1,15 +1,18 @@ import pytest +from utils import health_check + @pytest.fixture(scope="module") -async def bloom_560m_sharded(launcher): - async with launcher("bigscience/bloom-560m", num_shard=2) as client: +def bloom_560m_sharded(launcher): + with launcher("bigscience/bloom-560m", num_shard=2) as client: yield client - @pytest.mark.asyncio async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot): + await health_check(bloom_560m_sharded, 60) + response = await bloom_560m_sharded.generate("Test request", max_new_tokens=10) assert response.details.generated_tokens == 10 @@ -18,7 +21,11 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot): @pytest.mark.asyncio async def test_bloom_560m_sharded_load(bloom_560m_sharded, generate_load, snapshot): - responses = await generate_load(bloom_560m_sharded, "Test request", max_new_tokens=10, n=4) + await health_check(bloom_560m_sharded, 60) + + responses = await generate_load( + bloom_560m_sharded, "Test request", max_new_tokens=10, n=4 + ) assert len(responses) == 4 diff --git a/integration-tests/models/utils.py b/integration-tests/models/utils.py new file mode 100644 index 00000000..7ef04d7f --- /dev/null +++ b/integration-tests/models/utils.py @@ -0,0 +1,15 @@ +import time + +from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError +from text_generation import AsyncClient + + +async def health_check(client: AsyncClient, timeout: int = 60): + assert timeout > 0 + for _ in range(timeout): + try: + await client.generate("test") + return + except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e: + time.sleep(1) + raise e diff --git a/integration-tests/requirements.txt b/integration-tests/requirements.txt index cfcea059..9ecbb2ee 100644 --- a/integration-tests/requirements.txt +++ b/integration-tests/requirements.txt @@ -1,3 +1,5 @@ -syrupy==0.4.2 +syrupy text-generation==0.5.1 -pytest==7.3.1 \ No newline at end of file +pytest +pytest-asyncio==0.17.2 +docker \ No newline at end of file