feat(tests): add snapshot testing

This commit is contained in:
OlivierDehaene 2023-05-03 19:09:18 +02:00
parent d9e6c514b5
commit 421372f271
7 changed files with 294 additions and 51 deletions

View File

@ -4,7 +4,10 @@ import contextlib
import pytest import pytest
import asyncio import asyncio
import os import os
import docker
from datetime import datetime
from docker.errors import NotFound
from typing import Optional, List from typing import Optional, List
from aiohttp import ClientConnectorError from aiohttp import ClientConnectorError
@ -23,8 +26,10 @@ def event_loop():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def launcher(event_loop): def launcher(event_loop):
@contextlib.asynccontextmanager @contextlib.contextmanager
async def local_launcher_inner(model_id: str, num_shard: Optional[int] = None, quantize: bool = False): def local_launcher(
model_id: str, num_shard: Optional[int] = None, quantize: bool = False
):
port = 9999 port = 9999
master_port = 19999 master_port = 19999
@ -50,38 +55,10 @@ def launcher(event_loop):
with subprocess.Popen( with subprocess.Popen(
args, stdout=subprocess.PIPE, stderr=subprocess.PIPE args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
) as process: ) as process:
client = AsyncClient(f"http://localhost:{port}") yield AsyncClient(f"http://localhost:{port}")
healthy = False
for _ in range(60):
launcher_output = process.stdout.read1().decode("utf-8")
print(launcher_output)
exit_code = process.poll()
if exit_code is not None:
launcher_error = process.stderr.read1().decode("utf-8")
print(launcher_error)
raise RuntimeError(
f"text-generation-launcher terminated with exit code {exit_code}"
)
try:
await client.generate("test", max_new_tokens=1)
healthy = True
break
except ClientConnectorError:
time.sleep(1)
if healthy:
yield client
process.terminate() process.terminate()
process.wait(60)
for _ in range(60):
exit_code = process.wait(1)
if exit_code is not None:
break
launcher_output = process.stdout.read1().decode("utf-8") launcher_output = process.stdout.read1().decode("utf-8")
print(launcher_output) print(launcher_output)
@ -89,16 +66,65 @@ def launcher(event_loop):
process.stdout.close() process.stdout.close()
process.stderr.close() process.stderr.close()
if not healthy: @contextlib.contextmanager
raise RuntimeError(f"unable to start model {model_id} with command: {' '.join(args)}") def docker_launcher(
model_id: str, num_shard: Optional[int] = None, quantize: bool = False
):
port = 9999
return launcher_inner args = ["--model-id", model_id, "--env"]
if num_shard is not None:
args.extend(["--num-shard", str(num_shard)])
if quantize:
args.append("--quantize")
client = docker.from_env()
container_name = f"tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}"
try:
container = client.containers.get(container_name)
container.stop()
container.wait()
except NotFound:
pass
gpu_count = num_shard if num_shard is not None else 1
container = client.containers.run(
DOCKER_IMAGE,
command=args,
name=container_name,
auto_remove=True,
detach=True,
device_requests=[
docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
],
volumes=["/data:/data"],
ports={"80/tcp": port},
)
yield AsyncClient(f"http://localhost:{port}")
container.stop()
container_output = container.logs().decode("utf-8")
print(container_output)
if DOCKER_IMAGE is not None:
return docker_launcher
return local_launcher
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def generate_load(): def generate_load():
async def generate_load_inner(client: AsyncClient, prompt: str, max_new_tokens: int, n: int) -> List[Response]: async def generate_load_inner(
futures = [client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)] client: AsyncClient, prompt: str, max_new_tokens: int, n: int
) -> List[Response]:
futures = [
client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
]
results = await asyncio.gather(*futures) results = await asyncio.gather(*futures)
return results return results

View File

@ -0,0 +1,94 @@
# serializer version: 1
# name: test_bloom_560m
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 10264,
'logprob': None,
'text': 'Test',
}),
dict({
'id': 8821,
'logprob': -11.3125,
'text': ' request',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 11,
'logprob': -2.859375,
'special': False,
'text': '(',
}),
dict({
'id': 5,
'logprob': -2.34375,
'special': False,
'text': '"',
}),
dict({
'id': 1587,
'logprob': -3.25,
'special': False,
'text': 'get',
}),
dict({
'id': 5,
'logprob': -1.828125,
'special': False,
'text': '"',
}),
dict({
'id': 15,
'logprob': -0.35546875,
'special': False,
'text': ',',
}),
dict({
'id': 567,
'logprob': -2.4375,
'special': False,
'text': ' "',
}),
dict({
'id': 17,
'logprob': -4.40625,
'special': False,
'text': '.',
}),
dict({
'id': 5,
'logprob': -2.46875,
'special': False,
'text': '"',
}),
dict({
'id': 12,
'logprob': -1.6015625,
'special': False,
'text': ')',
}),
dict({
'id': 30,
'logprob': -0.14550781,
'special': False,
'text': ';',
}),
]),
}),
'generated_text': '("get", ".");',
})
# ---
# name: test_bloom_560m_load
list([
Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
])
# ---

View File

@ -0,0 +1,94 @@
# serializer version: 1
# name: test_bloom_560m_sharded
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 10264,
'logprob': None,
'text': 'Test',
}),
dict({
'id': 8821,
'logprob': -11.375,
'text': ' request',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 11,
'logprob': -2.734375,
'special': False,
'text': '(',
}),
dict({
'id': 5,
'logprob': -1.9765625,
'special': False,
'text': '"',
}),
dict({
'id': 1587,
'logprob': -3.140625,
'special': False,
'text': 'get',
}),
dict({
'id': 5,
'logprob': -3.515625,
'special': False,
'text': '"',
}),
dict({
'id': 15,
'logprob': -0.37304688,
'special': False,
'text': ',',
}),
dict({
'id': 567,
'logprob': -2.6875,
'special': False,
'text': ' "',
}),
dict({
'id': 17,
'logprob': -4.65625,
'special': False,
'text': '.',
}),
dict({
'id': 5,
'logprob': -2.28125,
'special': False,
'text': '"',
}),
dict({
'id': 12,
'logprob': -1.6640625,
'special': False,
'text': ')',
}),
dict({
'id': 30,
'logprob': -0.14355469,
'special': False,
'text': ';',
}),
]),
}),
'generated_text': '("get", ".");',
})
# ---
# name: test_bloom_560m_sharded_load
list([
Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)),
Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)),
Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)),
Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)),
])
# ---

View File

@ -1,24 +1,29 @@
import pytest import pytest
from utils import health_check
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
async def bloom_560m(launcher): def bloom_560(launcher):
async with launcher("bigscience/bloom-560m") as client: with launcher("bigscience/bloom-560m") as client:
yield client yield client
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m(bloom_560m, snapshot): async def test_bloom_560m(bloom_560, snapshot):
response = await bloom_560m.generate("Test request", max_new_tokens=10) await health_check(bloom_560, 60)
response = await bloom_560.generate("Test request", max_new_tokens=10)
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert response == snapshot assert response == snapshot
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m_load(bloom_560m, generate_load, snapshot): async def test_bloom_560m_load(bloom_560, generate_load, snapshot):
responses = await generate_load(bloom_560m, "Test request", max_new_tokens=10, n=4) await health_check(bloom_560, 60)
responses = await generate_load(bloom_560, "Test request", max_new_tokens=10, n=4)
assert len(responses) == 4 assert len(responses) == 4

View File

@ -1,15 +1,18 @@
import pytest import pytest
from utils import health_check
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
async def bloom_560m_sharded(launcher): def bloom_560m_sharded(launcher):
async with launcher("bigscience/bloom-560m", num_shard=2) as client: with launcher("bigscience/bloom-560m", num_shard=2) as client:
yield client yield client
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot): async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot):
await health_check(bloom_560m_sharded, 60)
response = await bloom_560m_sharded.generate("Test request", max_new_tokens=10) response = await bloom_560m_sharded.generate("Test request", max_new_tokens=10)
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
@ -18,7 +21,11 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m_sharded_load(bloom_560m_sharded, generate_load, snapshot): async def test_bloom_560m_sharded_load(bloom_560m_sharded, generate_load, snapshot):
responses = await generate_load(bloom_560m_sharded, "Test request", max_new_tokens=10, n=4) await health_check(bloom_560m_sharded, 60)
responses = await generate_load(
bloom_560m_sharded, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4 assert len(responses) == 4

View File

@ -0,0 +1,15 @@
import asyncio
import time

from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from text_generation import AsyncClient
async def health_check(client: "AsyncClient", timeout: int = 60):
    """Wait until the text-generation server answers a generate request.

    Polls ``client.generate`` once per second until it succeeds or *timeout*
    attempts have been made.

    Args:
        client: async text-generation client; only its ``generate`` coroutine
            is used.
        timeout: number of 1-second retry attempts; must be positive.

    Raises:
        ValueError: if *timeout* is not positive (raised instead of ``assert``
            so the check survives ``python -O``).
        ClientConnectorError | ClientOSError | ServerDisconnectedError: the
            last connection error seen, if the server never became healthy.
    """
    if timeout <= 0:
        raise ValueError(f"timeout must be positive, got {timeout}")
    last_exc = None
    for _ in range(timeout):
        try:
            await client.generate("test")
            return
        except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
            # The except target `e` is unbound as soon as the except block
            # exits (PEP 3110), so keep a reference for the final raise.
            last_exc = e
            # Non-blocking sleep: time.sleep(1) would stall the event loop
            # this coroutine runs on.
            await asyncio.sleep(1)
    raise last_exc

View File

@ -1,3 +1,5 @@
syrupy==0.4.2 syrupy
text-generation==0.5.1 text-generation==0.5.1
pytest==7.3.1 pytest
pytest-asyncio==0.17.2
docker