Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
feat(tests): add snapshot testing

commit 421372f271
parent d9e6c514b5
integration-tests/conftest.py

@@ -4,7 +4,10 @@ import contextlib
 import pytest
 import asyncio
 import os
+import docker

+from datetime import datetime
+from docker.errors import NotFound
 from typing import Optional, List
 from aiohttp import ClientConnectorError

@@ -23,8 +26,10 @@ def event_loop():

 @pytest.fixture(scope="module")
 def launcher(event_loop):
-    @contextlib.asynccontextmanager
-    async def local_launcher_inner(model_id: str, num_shard: Optional[int] = None, quantize: bool = False):
+    @contextlib.contextmanager
+    def local_launcher(
+        model_id: str, num_shard: Optional[int] = None, quantize: bool = False
+    ):
         port = 9999
         master_port = 19999

@@ -48,40 +53,12 @@ def launcher(event_loop):
             args.append("--quantize")

         with subprocess.Popen(
             args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
         ) as process:
-            client = AsyncClient(f"http://localhost:{port}")
-
-            healthy = False
-
-            for _ in range(60):
-                launcher_output = process.stdout.read1().decode("utf-8")
-                print(launcher_output)
-
-                exit_code = process.poll()
-                if exit_code is not None:
-                    launcher_error = process.stderr.read1().decode("utf-8")
-                    print(launcher_error)
-                    raise RuntimeError(
-                        f"text-generation-launcher terminated with exit code {exit_code}"
-                    )
-
-                try:
-                    await client.generate("test", max_new_tokens=1)
-                    healthy = True
-                    break
-                except ClientConnectorError:
-                    time.sleep(1)
-
-            if healthy:
-                yield client
+            yield AsyncClient(f"http://localhost:{port}")

             process.terminate()
-
-            for _ in range(60):
-                exit_code = process.wait(1)
-                if exit_code is not None:
-                    break
+            process.wait(60)

             launcher_output = process.stdout.read1().decode("utf-8")
             print(launcher_output)
@@ -89,16 +66,65 @@ def launcher(event_loop):
             process.stdout.close()
             process.stderr.close()

-        if not healthy:
-            raise RuntimeError(f"unable to start model {model_id} with command: {' '.join(args)}")
+    @contextlib.contextmanager
+    def docker_launcher(
+        model_id: str, num_shard: Optional[int] = None, quantize: bool = False
+    ):
+        port = 9999

-        return launcher_inner
+        args = ["--model-id", model_id, "--env"]
+
+        if num_shard is not None:
+            args.extend(["--num-shard", str(num_shard)])
+        if quantize:
+            args.append("--quantize")
+
+        client = docker.from_env()
+
+        container_name = f"tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}"
+
+        try:
+            container = client.containers.get(container_name)
+            container.stop()
+            container.wait()
+        except NotFound:
+            pass
+
+        gpu_count = num_shard if num_shard is not None else 1
+
+        container = client.containers.run(
+            DOCKER_IMAGE,
+            command=args,
+            name=container_name,
+            auto_remove=True,
+            detach=True,
+            device_requests=[
+                docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
+            ],
+            volumes=["/data:/data"],
+            ports={"80/tcp": port},
+        )
+
+        yield AsyncClient(f"http://localhost:{port}")
+
+        container.stop()
+
+        container_output = container.logs().decode("utf-8")
+        print(container_output)
+
+    if DOCKER_IMAGE is not None:
+        return docker_launcher
+    return local_launcher


 @pytest.fixture(scope="module")
 def generate_load():
-    async def generate_load_inner(client: AsyncClient, prompt: str, max_new_tokens: int, n: int) -> List[Response]:
-        futures = [client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)]
+    async def generate_load_inner(
+        client: AsyncClient, prompt: str, max_new_tokens: int, n: int
+    ) -> List[Response]:
+        futures = [
+            client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
+        ]

         results = await asyncio.gather(*futures)
         return results
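Note: both launchers reference a module-level DOCKER_IMAGE constant that does not appear in these hunks. It is presumably defined near the top of conftest.py as an environment lookup, along these lines (a sketch, not part of the diff):

import os

# Assumed definition: when DOCKER_IMAGE is set, the launcher fixture returns
# docker_launcher; otherwise it falls back to local_launcher.
DOCKER_IMAGE = os.environ.get("DOCKER_IMAGE", None)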
integration-tests/models/__snapshots__/test_bloom_560m.ambr  (new file, 94 lines)

@@ -0,0 +1,94 @@
# serializer version: 1
# name: test_bloom_560m
  dict({
    'details': dict({
      'best_of_sequences': None,
      'finish_reason': <FinishReason.Length: 'length'>,
      'generated_tokens': 10,
      'prefill': list([
        dict({
          'id': 10264,
          'logprob': None,
          'text': 'Test',
        }),
        dict({
          'id': 8821,
          'logprob': -11.3125,
          'text': ' request',
        }),
      ]),
      'seed': None,
      'tokens': list([
        dict({
          'id': 11,
          'logprob': -2.859375,
          'special': False,
          'text': '(',
        }),
        dict({
          'id': 5,
          'logprob': -2.34375,
          'special': False,
          'text': '"',
        }),
        dict({
          'id': 1587,
          'logprob': -3.25,
          'special': False,
          'text': 'get',
        }),
        dict({
          'id': 5,
          'logprob': -1.828125,
          'special': False,
          'text': '"',
        }),
        dict({
          'id': 15,
          'logprob': -0.35546875,
          'special': False,
          'text': ',',
        }),
        dict({
          'id': 567,
          'logprob': -2.4375,
          'special': False,
          'text': ' "',
        }),
        dict({
          'id': 17,
          'logprob': -4.40625,
          'special': False,
          'text': '.',
        }),
        dict({
          'id': 5,
          'logprob': -2.46875,
          'special': False,
          'text': '"',
        }),
        dict({
          'id': 12,
          'logprob': -1.6015625,
          'special': False,
          'text': ')',
        }),
        dict({
          'id': 30,
          'logprob': -0.14550781,
          'special': False,
          'text': ';',
        }),
      ]),
    }),
    'generated_text': '("get", ".");',
  })
# ---
# name: test_bloom_560m_load
  list([
    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.3125)], tokens=[Token(id=11, text='(', logprob=-2.859375, special=False), Token(id=5, text='"', logprob=-2.421875, special=False), Token(id=1587, text='get', logprob=-3.109375, special=False), Token(id=5, text='"', logprob=-2.046875, special=False), Token(id=15, text=',', logprob=-0.37890625, special=False), Token(id=567, text=' "', logprob=-2.65625, special=False), Token(id=17, text='.', logprob=-4.28125, special=False), Token(id=5, text='"', logprob=-2.1875, special=False), Token(id=12, text=')', logprob=-1.6328125, special=False), Token(id=30, text=';', logprob=-0.16113281, special=False)], best_of_sequences=None)),
  ])
# ---
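These .ambr files are recorded and compared by syrupy via the `snapshot` fixture used in the test modules below. A minimal sketch of the round trip, assuming syrupy's default amber serializer:

import pytest

@pytest.mark.asyncio
async def test_bloom_560m(bloom_560, snapshot):
    response = await bloom_560.generate("Test request", max_new_tokens=10)
    # Compared against the entry whose name matches the test,
    # i.e. '# name: test_bloom_560m' above.
    assert response == snapshot

# After an intentional output change, snapshots are regenerated with:
#   pytest --snapshot-update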
integration-tests/models/__snapshots__/test_bloom_560m_sharded.ambr  (new file, 94 lines)

@@ -0,0 +1,94 @@
# serializer version: 1
# name: test_bloom_560m_sharded
  dict({
    'details': dict({
      'best_of_sequences': None,
      'finish_reason': <FinishReason.Length: 'length'>,
      'generated_tokens': 10,
      'prefill': list([
        dict({
          'id': 10264,
          'logprob': None,
          'text': 'Test',
        }),
        dict({
          'id': 8821,
          'logprob': -11.375,
          'text': ' request',
        }),
      ]),
      'seed': None,
      'tokens': list([
        dict({
          'id': 11,
          'logprob': -2.734375,
          'special': False,
          'text': '(',
        }),
        dict({
          'id': 5,
          'logprob': -1.9765625,
          'special': False,
          'text': '"',
        }),
        dict({
          'id': 1587,
          'logprob': -3.140625,
          'special': False,
          'text': 'get',
        }),
        dict({
          'id': 5,
          'logprob': -3.515625,
          'special': False,
          'text': '"',
        }),
        dict({
          'id': 15,
          'logprob': -0.37304688,
          'special': False,
          'text': ',',
        }),
        dict({
          'id': 567,
          'logprob': -2.6875,
          'special': False,
          'text': ' "',
        }),
        dict({
          'id': 17,
          'logprob': -4.65625,
          'special': False,
          'text': '.',
        }),
        dict({
          'id': 5,
          'logprob': -2.28125,
          'special': False,
          'text': '"',
        }),
        dict({
          'id': 12,
          'logprob': -1.6640625,
          'special': False,
          'text': ')',
        }),
        dict({
          'id': 30,
          'logprob': -0.14355469,
          'special': False,
          'text': ';',
        }),
      ]),
    }),
    'generated_text': '("get", ".");',
  })
# ---
# name: test_bloom_560m_sharded_load
  list([
    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)),
    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)),
    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)),
    Response(generated_text='("get", ".");', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=10264, text='Test', logprob=None), PrefillToken(id=8821, text=' request', logprob=-11.375)], tokens=[Token(id=11, text='(', logprob=-2.734375, special=False), Token(id=5, text='"', logprob=-2.328125, special=False), Token(id=1587, text='get', logprob=-2.953125, special=False), Token(id=5, text='"', logprob=-3.546875, special=False), Token(id=15, text=',', logprob=-0.359375, special=False), Token(id=567, text=' "', logprob=-2.484375, special=False), Token(id=17, text='.', logprob=-4.0625, special=False), Token(id=5, text='"', logprob=-2.46875, special=False), Token(id=12, text=')', logprob=-1.671875, special=False), Token(id=30, text=';', logprob=-0.18164062, special=False)], best_of_sequences=None)),
  ])
# ---
integration-tests/models/test_bloom_560m.py

@@ -1,24 +1,29 @@
 import pytest

+from utils import health_check
+

 @pytest.fixture(scope="module")
-async def bloom_560m(launcher):
-    async with launcher("bigscience/bloom-560m") as client:
+def bloom_560(launcher):
+    with launcher("bigscience/bloom-560m") as client:
         yield client


 @pytest.mark.asyncio
-async def test_bloom_560m(bloom_560m, snapshot):
-    response = await bloom_560m.generate("Test request", max_new_tokens=10)
+async def test_bloom_560m(bloom_560, snapshot):
+    await health_check(bloom_560, 60)
+
+    response = await bloom_560.generate("Test request", max_new_tokens=10)

     assert response.details.generated_tokens == 10
     assert response == snapshot


 @pytest.mark.asyncio
-async def test_bloom_560m_load(bloom_560m, generate_load, snapshot):
-    responses = await generate_load(bloom_560m, "Test request", max_new_tokens=10, n=4)
+async def test_bloom_560m_load(bloom_560, generate_load, snapshot):
+    await health_check(bloom_560, 60)
+
+    responses = await generate_load(bloom_560, "Test request", max_new_tokens=10, n=4)

     assert len(responses) == 4
integration-tests/models/test_bloom_560m_sharded.py

@@ -1,15 +1,18 @@
 import pytest

+from utils import health_check
+

 @pytest.fixture(scope="module")
-async def bloom_560m_sharded(launcher):
-    async with launcher("bigscience/bloom-560m", num_shard=2) as client:
+def bloom_560m_sharded(launcher):
+    with launcher("bigscience/bloom-560m", num_shard=2) as client:
         yield client


 @pytest.mark.asyncio
 async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot):
+    await health_check(bloom_560m_sharded, 60)
+
     response = await bloom_560m_sharded.generate("Test request", max_new_tokens=10)

     assert response.details.generated_tokens == 10
@@ -18,7 +21,11 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot):

 @pytest.mark.asyncio
 async def test_bloom_560m_sharded_load(bloom_560m_sharded, generate_load, snapshot):
-    responses = await generate_load(bloom_560m_sharded, "Test request", max_new_tokens=10, n=4)
+    await health_check(bloom_560m_sharded, 60)
+
+    responses = await generate_load(
+        bloom_560m_sharded, "Test request", max_new_tokens=10, n=4
+    )

     assert len(responses) == 4
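For the sharded variant, num_shard=2 flows through the launcher fixture introduced in conftest.py. An illustrative trace of that mapping (comments only, inferred from the hunks above):

# launcher("bigscience/bloom-560m", num_shard=2)
#   -> command args gain: ["--num-shard", "2"]
#   -> in the Docker path: DeviceRequest(count=2, capabilities=[["gpu"]]),
#      i.e. the container requests one GPU per shard.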
integration-tests/models/utils.py  (new file, 15 lines)

@@ -0,0 +1,15 @@
import time

from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from text_generation import AsyncClient


async def health_check(client: AsyncClient, timeout: int = 60):
    assert timeout > 0
    for _ in range(timeout):
        try:
            await client.generate("test")
            return
        except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
            time.sleep(1)
    raise e
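Two caveats in health_check as committed: `raise e` runs after the for loop, where `e` is already unbound (Python clears the `except ... as e` name when the handler exits), so an exhausted timeout raises NameError instead of the last connection error; and `time.sleep(1)` blocks the event loop inside an async function. A corrected sketch, keeping the same public signature:

import asyncio

from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from text_generation import AsyncClient


async def health_check(client: AsyncClient, timeout: int = 60):
    assert timeout > 0
    last_error = None
    for _ in range(timeout):
        try:
            await client.generate("test")
            return
        except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
            last_error = e  # keep a reference that survives the except block
            await asyncio.sleep(1)  # yield to the event loop instead of blocking it
    raise last_error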
integration-tests/requirements.txt

@@ -1,3 +1,5 @@
-syrupy
+syrupy==0.4.2
 text-generation==0.5.1
-pytest
+pytest==7.3.1
+pytest-asyncio==0.17.2
+docker
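With the dependencies pinned above, the suite would presumably be run with something along the lines of (commands are illustrative, not part of the commit):

pip install -r integration-tests/requirements.txt
pytest integration-tests/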