working local

This commit is contained in:
OlivierDehaene 2023-05-03 11:44:58 +02:00
parent 338d30b43b
commit d9e6c514b5
3 changed files with 134 additions and 58 deletions

View File

@@ -1,74 +1,106 @@
import pytest
import subprocess
import time
import contextlib
import pytest
import asyncio
import os
from text_generation import Client
from typing import Optional
from requests import ConnectionError
from typing import Optional, List
from aiohttp import ClientConnectorError
from text_generation import AsyncClient
from text_generation.types import Response
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@contextlib.contextmanager
def launcher(model_id: str, num_shard: Optional[int] = None, quantize: bool = False):
port = 9999
master_port = 19999
@pytest.fixture(scope="module")
def event_loop():
    """Provide a fresh, module-scoped event loop for async fixtures/tests.

    Uses the policy's `new_event_loop()` instead of the deprecated
    `asyncio.get_event_loop()`: the latter warns (and, from 3.12, raises)
    when no loop is running, and closing the shared default loop would
    break any later code that relies on it.
    """
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    # Teardown: close the loop we created so its resources are released.
    loop.close()
shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server"
args = [
"text-generation-launcher",
"--model-id",
model_id,
"--port",
str(port),
"--master-port",
str(master_port),
"--shard-uds-path",
shard_uds_path,
]
@pytest.fixture(scope="module")
def launcher(event_loop):
    """Return an async context manager that launches `text-generation-launcher`
    for a given model and yields a ready `AsyncClient` pointed at it.

    Usage: `async with launcher("org/model", num_shard=2) as client: ...`
    """

    @contextlib.asynccontextmanager
    async def local_launcher_inner(
        model_id: str, num_shard: Optional[int] = None, quantize: bool = False
    ):
        port = 9999
        master_port = 19999
        # One UDS path per model id so different models don't collide.
        shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server"

        args = [
            "text-generation-launcher",
            "--model-id",
            model_id,
            "--port",
            str(port),
            "--master-port",
            str(master_port),
            "--shard-uds-path",
            shard_uds_path,
        ]
        if num_shard is not None:
            # argv entries must be strings, not ints.
            args.extend(["--num-shard", str(num_shard)])
        if quantize:
            args.append("--quantize")

        with subprocess.Popen(
            args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        ) as process:
            client = AsyncClient(f"http://localhost:{port}")

            # Poll for up to ~60s until the server answers a tiny generate
            # request, echoing launcher output for debuggability.
            healthy = False
            for _ in range(60):
                launcher_output = process.stdout.read1().decode("utf-8")
                print(launcher_output)

                exit_code = process.poll()
                if exit_code is not None:
                    launcher_error = process.stderr.read1().decode("utf-8")
                    print(launcher_error)
                    raise RuntimeError(
                        f"text-generation-launcher terminated with exit code {exit_code}"
                    )

                try:
                    await client.generate("test", max_new_tokens=1)
                    healthy = True
                    break
                except ClientConnectorError:
                    # BUG FIX: use asyncio.sleep — a blocking time.sleep(1)
                    # inside an async function stalls the event loop.
                    await asyncio.sleep(1)

            if not healthy:
                raise RuntimeError(
                    f"unable to start model {model_id} with command: {' '.join(args)}"
                )

            yield client

            # Teardown: close pipes, then terminate and reap the process.
            process.stdout.close()
            process.stderr.close()
            process.terminate()
            try:
                # BUG FIX: Popen.wait(timeout) raises TimeoutExpired rather
                # than returning None, so the original poll loop would crash.
                process.wait(60)
            except subprocess.TimeoutExpired:
                process.kill()

    # BUG FIX: the original returned the undefined name `launcher_inner`
    # (NameError); the inner function is `local_launcher_inner`.
    return local_launcher_inner
@pytest.fixture(scope="session")
def bloom_560m():
    # Pre-async (old-version) session fixture: launch bigscience/bloom-560m
    # via the synchronous `launcher` context manager and yield the client.
    # NOTE(review): relies on a module-level sync `launcher`; the async
    # rewrite replaces this with a module-scoped async fixture — confirm
    # this residue is meant to be deleted.
    with launcher("bigscience/bloom-560m") as client:
        yield client
@pytest.fixture(scope="module")
def generate_load():
    """Return a helper that fires `n` concurrent generate requests.

    The helper issues all requests at once via `asyncio.gather`, so the
    responses come back in request order as a list.
    """

    async def generate_load_inner(
        client: AsyncClient, prompt: str, max_new_tokens: int, n: int
    ) -> List[Response]:
        # Kick off all n coroutines concurrently, then await them together.
        futures = [
            client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
        ]
        return await asyncio.gather(*futures)

    return generate_load_inner

View File

@@ -1,6 +1,25 @@
def test_bloom_560m(bloom_560m, snapshot):
    # Pre-async (old-version) smoke test: issues one generate call; the
    # multi-shard comparison and snapshot assertion below are commented out,
    # so this test only checks that generation does not raise.
    response = bloom_560m.generate("Test request")
    # response_multi = bloom_560m_multi.generate("Test request")
    # assert response == response_multi == snapshot
import pytest
@pytest.fixture(scope="module")
async def bloom_560m(launcher):
    # Launch bloom-560m once for this module and hand the live client to tests;
    # the launcher context manager tears the server down afterwards.
    async with launcher("bigscience/bloom-560m") as bloom_client:
        yield bloom_client
@pytest.mark.asyncio
async def test_bloom_560m(bloom_560m, snapshot):
    # Request a fixed-length generation and compare the full response
    # against the stored snapshot.
    result = await bloom_560m.generate("Test request", max_new_tokens=10)
    assert result.details.generated_tokens == 10
    assert result == snapshot
@pytest.mark.asyncio
async def test_bloom_560m_load(bloom_560m, generate_load, snapshot):
    # Fire four concurrent requests and snapshot the whole batch.
    batch = await generate_load(bloom_560m, "Test request", max_new_tokens=10, n=4)
    assert len(batch) == 4
    assert batch == snapshot

View File

@@ -0,0 +1,25 @@
import pytest
@pytest.fixture(scope="module")
async def bloom_560m_sharded(launcher):
    # Launch bloom-560m across two shards for this module's tests; the
    # launcher context manager tears the server down afterwards.
    async with launcher("bigscience/bloom-560m", num_shard=2) as sharded_client:
        yield sharded_client
@pytest.mark.asyncio
async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot):
    # Fixed-length generation against the sharded server, snapshot-compared.
    result = await bloom_560m_sharded.generate("Test request", max_new_tokens=10)
    assert result.details.generated_tokens == 10
    assert result == snapshot
@pytest.mark.asyncio
async def test_bloom_560m_sharded_load(bloom_560m_sharded, generate_load, snapshot):
    # Four concurrent requests against the sharded server, snapshotted as a batch.
    batch = await generate_load(
        bloom_560m_sharded, "Test request", max_new_tokens=10, n=4
    )
    assert len(batch) == 4
    assert batch == snapshot