feat(tests): add snapshot testing

OlivierDehaene 2023-05-03 19:09:18 +02:00
parent d9e6c514b5
commit 421372f271
7 changed files with 294 additions and 51 deletions

View File

@@ -4,7 +4,10 @@ import contextlib
import pytest
import asyncio
import os
import docker
from datetime import datetime
from docker.errors import NotFound
from typing import Optional, List
from aiohttp import ClientConnectorError
@@ -23,8 +26,10 @@ def event_loop():
@pytest.fixture(scope="module")
def launcher(event_loop):
    @contextlib.asynccontextmanager
    async def launcher_inner(model_id: str, num_shard: Optional[int] = None, quantize: bool = False):
    @contextlib.contextmanager
    def local_launcher(
        model_id: str, num_shard: Optional[int] = None, quantize: bool = False
    ):
        port = 9999
        master_port = 19999
@@ -48,40 +53,12 @@ def launcher(event_loop):
            args.append("--quantize")

        with subprocess.Popen(
            args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        ) as process:
            client = AsyncClient(f"http://localhost:{port}")

            healthy = False
            for _ in range(60):
                launcher_output = process.stdout.read1().decode("utf-8")
                print(launcher_output)

                exit_code = process.poll()
                if exit_code is not None:
                    launcher_error = process.stderr.read1().decode("utf-8")
                    print(launcher_error)
                    raise RuntimeError(
                        f"text-generation-launcher terminated with exit code {exit_code}"
                    )

                try:
                    await client.generate("test", max_new_tokens=1)
                    healthy = True
                    break
                except ClientConnectorError:
                    time.sleep(1)

            if healthy:
                yield client
            yield AsyncClient(f"http://localhost:{port}")

            process.terminate()
            for _ in range(60):
                exit_code = process.wait(1)
                if exit_code is not None:
                    break
            process.wait(60)

            launcher_output = process.stdout.read1().decode("utf-8")
            print(launcher_output)
@@ -89,16 +66,65 @@ def launcher(event_loop):
            process.stdout.close()
            process.stderr.close()

            if not healthy:
                raise RuntimeError(f"unable to start model {model_id} with command: {' '.join(args)}")

    @contextlib.contextmanager
    def docker_launcher(
        model_id: str, num_shard: Optional[int] = None, quantize: bool = False
    ):
        port = 9999

    return launcher_inner

        args = ["--model-id", model_id, "--env"]

        if num_shard is not None:
            args.extend(["--num-shard", str(num_shard)])
        if quantize:
            args.append("--quantize")

        client = docker.from_env()

        container_name = f"tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}"

        try:
            container = client.containers.get(container_name)
            container.stop()
            container.wait()
        except NotFound:
            pass

        gpu_count = num_shard if num_shard is not None else 1

        container = client.containers.run(
            DOCKER_IMAGE,
            command=args,
            name=container_name,
            auto_remove=True,
            detach=True,
            device_requests=[
                docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
            ],
            volumes=["/data:/data"],
            ports={"80/tcp": port},
        )

        yield AsyncClient(f"http://localhost:{port}")

        container.stop()

        container_output = container.logs().decode("utf-8")
        print(container_output)

    if DOCKER_IMAGE is not None:
        return docker_launcher
    return local_launcher
@pytest.fixture(scope="module")
def generate_load():
async def generate_load_inner(client: AsyncClient, prompt: str, max_new_tokens: int, n: int) -> List[Response]:
futures = [client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)]
async def generate_load_inner(
client: AsyncClient, prompt: str, max_new_tokens: int, n: int
) -> List[Response]:
futures = [
client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
]
results = await asyncio.gather(*futures)
return results
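Taken together, the launcher fixture now has two interchangeable backends: local_launcher spawns a text-generation-launcher subprocess, while docker_launcher runs DOCKER_IMAGE with GPUs and the /data volume attached and port 80 published on localhost:9999, the choice being made once from the DOCKER_IMAGE environment variable. A minimal sketch of a new test module built on these fixtures (the model id and fixture name are illustrative, not part of this commit):

import pytest

from utils import health_check


@pytest.fixture(scope="module")
def my_model(launcher):
    # `launcher` resolves to local_launcher when DOCKER_IMAGE is unset,
    # docker_launcher otherwise; either way the test talks to localhost:9999.
    with launcher("my-org/my-model", num_shard=1) as client:
        yield client


@pytest.mark.asyncio
async def test_my_model(my_model, snapshot):
    await health_check(my_model, 60)
    response = await my_model.generate("Test request", max_new_tokens=10)
    assert response == snapshot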

View File

@@ -0,0 +1,94 @@
# serializer version: 1
# name: test_bloom_560m
  dict({
    'details': dict({
      'best_of_sequences': None,
      'finish_reason': <FinishReason.Length: 'length'>,
      'generated_tokens': 10,
      'prefill': list([
        dict({
          'id': 10264,
          'logprob': None,
          'text': 'Test',
        }),
        dict({
          'id': 8821,
          'logprob': -11.3125,
          'text': ' request',
        }),
      ]),
      'seed': None,
      'tokens': list([
        dict({
          'id': 11,
          'logprob': -2.859375,
          'special': False,
          'text': '(',
        }),
        dict({
          'id': 5,
          'logprob': -2.34375,
          'special': False,
          'text': '"',
        }),
        dict({
          'id': 1587,
          'logprob': -3.25,
          'special': False,
          'text': 'get',
        }),
        dict({
          'id': 5,
          'logprob': -1.828125,
          'special': False,
          'text': '"',
        }),
        dict({
          'id': 15,
          'logprob': -0.35546875,
          'special': False,
          'text': ',',
        }),
        dict({
          'id': 567,
          'logprob': -2.4375,
          'special': False,
          'text': ' "',
        }),
        dict({
          'id': 17,
          'logprob': -4.40625,
          'special': False,
          'text': '.',
        }),
        dict({
          'id': 5,
          'logprob': -2.46875,
          'special': False,
          'text': '"',
        }),
        dict({
          'id': 12,
          'logprob': -1.6015625,
          'special': False,
          'text': ')',
        }),
        dict({
          'id': 30,
          'logprob': -0.14550781,
          'special': False,
          'text': ';',
        }),
      ]),
    }),
    'generated_text': '("get", ".");',
  })
# ---
# name: test_bloom_560m_load
  list([
    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
  ])
# ---
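These .ambr files are syrupy's amber-format snapshots: the first run with pytest --snapshot-update serializes each asserted object into an entry named after its test, and later runs fail whenever the comparison drifts. A sketch of the round trip, assuming syrupy's default amber serializer:

# Record or refresh the entries above:
#   pytest --snapshot-update
# Replay on normal runs:
@pytest.mark.asyncio
async def test_bloom_560m(bloom_560, snapshot):
    response = await bloom_560.generate("Test request", max_new_tokens=10)
    # Fails if any token id, text, or logprob differs from the stored entry.
    assert response == snapshot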

View File

@@ -0,0 +1,94 @@
# serializer version: 1
# name: test_bloom_560m_sharded
  dict({
    'details': dict({
      'best_of_sequences': None,
      'finish_reason': <FinishReason.Length: 'length'>,
      'generated_tokens': 10,
      'prefill': list([
        dict({
          'id': 10264,
          'logprob': None,
          'text': 'Test',
        }),
        dict({
          'id': 8821,
          'logprob': -11.375,
          'text': ' request',
        }),
      ]),
      'seed': None,
      'tokens': list([
        dict({
          'id': 11,
          'logprob': -2.734375,
          'special': False,
          'text': '(',
        }),
        dict({
          'id': 5,
          'logprob': -1.9765625,
          'special': False,
          'text': '"',
        }),
        dict({
          'id': 1587,
          'logprob': -3.140625,
          'special': False,
          'text': 'get',
        }),
        dict({
          'id': 5,
          'logprob': -3.515625,
          'special': False,
          'text': '"',
        }),
        dict({
          'id': 15,
          'logprob': -0.37304688,
          'special': False,
          'text': ',',
        }),
        dict({
          'id': 567,
          'logprob': -2.6875,
          'special': False,
          'text': ' "',
        }),
        dict({
          'id': 17,
          'logprob': -4.65625,
          'special': False,
          'text': '.',
        }),
        dict({
          'id': 5,
          'logprob': -2.28125,
          'special': False,
          'text': '"',
        }),
        dict({
          'id': 12,
          'logprob': -1.6640625,
          'special': False,
          'text': ')',
        }),
        dict({
          'id': 30,
          'logprob': -0.14355469,
          'special': False,
          'text': ';',
        }),
      ]),
    }),
    'generated_text': '("get", ".");',
  })
# ---
# name: test_bloom_560m_sharded_load
  list([
    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)),
    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)),
    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)),
    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)),
  ])
# ---
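Note that the sharded logprobs (for instance -11.375 for ' request') differ slightly from the single-shard values (-11.3125): tensor-parallel execution changes floating-point reduction order, which is why each sharding configuration keeps its own snapshot file. If exact matching ever proves too brittle across hardware, a tolerance-based comparison is one possible fallback (a sketch, not part of this commit; assumes the Token type from text_generation.types):

import math

from typing import List

from text_generation.types import Token


def tokens_close(a: List[Token], b: List[Token], rel_tol: float = 5e-2) -> bool:
    # Require identical token ids and texts, but tolerate small logprob
    # drift between hardware or sharding configurations.
    if len(a) != len(b):
        return False
    return all(
        x.id == y.id
        and x.text == y.text
        and math.isclose(x.logprob, y.logprob, rel_tol=rel_tol)
        for x, y in zip(a, b)
    )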

View File

@@ -1,24 +1,29 @@
import pytest

from utils import health_check


@pytest.fixture(scope="module")
async def bloom_560m(launcher):
    async with launcher("bigscience/bloom-560m") as client:
def bloom_560(launcher):
    with launcher("bigscience/bloom-560m") as client:
        yield client


@pytest.mark.asyncio
async def test_bloom_560m(bloom_560m, snapshot):
    response = await bloom_560m.generate("Test request", max_new_tokens=10)
async def test_bloom_560m(bloom_560, snapshot):
    await health_check(bloom_560, 60)

    response = await bloom_560.generate("Test request", max_new_tokens=10)

    assert response.details.generated_tokens == 10
    assert response == snapshot


@pytest.mark.asyncio
async def test_bloom_560m_load(bloom_560m, generate_load, snapshot):
    responses = await generate_load(bloom_560m, "Test request", max_new_tokens=10, n=4)
async def test_bloom_560m_load(bloom_560, generate_load, snapshot):
    await health_check(bloom_560, 60)

    responses = await generate_load(bloom_560, "Test request", max_new_tokens=10, n=4)

    assert len(responses) == 4
View File

@@ -1,15 +1,18 @@
import pytest

from utils import health_check


@pytest.fixture(scope="module")
async def bloom_560m_sharded(launcher):
    async with launcher("bigscience/bloom-560m", num_shard=2) as client:
def bloom_560m_sharded(launcher):
    with launcher("bigscience/bloom-560m", num_shard=2) as client:
        yield client


@pytest.mark.asyncio
async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot):
    await health_check(bloom_560m_sharded, 60)

    response = await bloom_560m_sharded.generate("Test request", max_new_tokens=10)

    assert response.details.generated_tokens == 10
@@ -18,7 +21,11 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot):
@pytest.mark.asyncio
async def test_bloom_560m_sharded_load(bloom_560m_sharded, generate_load, snapshot):
    responses = await generate_load(bloom_560m_sharded, "Test request", max_new_tokens=10, n=4)
    await health_check(bloom_560m_sharded, 60)

    responses = await generate_load(
        bloom_560m_sharded, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4

View File

@@ -0,0 +1,15 @@
import time

from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError

from text_generation import AsyncClient


async def health_check(client: AsyncClient, timeout: int = 60):
    assert timeout > 0
    last_exception = None
    for _ in range(timeout):
        try:
            await client.generate("test")
            return
        except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
            # `e` is unbound once the except block exits, so keep a reference
            # for the final raise after the retries are exhausted.
            last_exception = e
            time.sleep(1)
    raise last_exception
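Because health_check is itself a coroutine, the time.sleep(1) between retries blocks the event loop for the whole second. That is harmless in these single-client tests, but a non-blocking variant is straightforward should the suite ever poll several models concurrently (a sketch, not part of this commit):

import asyncio

from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError

from text_generation import AsyncClient


async def health_check_nonblocking(client: AsyncClient, timeout: int = 60):
    # Same retry contract as health_check, but yields control back to the
    # event loop between attempts instead of blocking it.
    assert timeout > 0
    last_exception = None
    for _ in range(timeout):
        try:
            await client.generate("test")
            return
        except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
            last_exception = e
            await asyncio.sleep(1)
    raise last_exception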

View File

@@ -1,3 +1,5 @@
syrupy==0.4.2
syrupy
text-generation==0.5.1
pytest==7.3.1
pytest
pytest-asyncio==0.17.2
docker