mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00

working local

This commit is contained in:
parent 338d30b43b
commit d9e6c514b5
@@ -1,74 +1,106 @@
-import pytest
-import subprocess
-import time
-import contextlib
-
-from text_generation import Client
-from typing import Optional
-from requests import ConnectionError
-
-
-@contextlib.contextmanager
-def launcher(model_id: str, num_shard: Optional[int] = None, quantize: bool = False):
-    port = 9999
-    master_port = 19999
-
-    shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server"
-
-    args = [
-        "text-generation-launcher",
-        "--model-id",
-        model_id,
-        "--port",
-        str(port),
-        "--master-port",
-        str(master_port),
-        "--shard-uds-path",
-        shard_uds_path,
-    ]
-
-    if num_shard is not None:
-        args.extend(["--num-shard", num_shard])
-    if quantize:
-        args.append("--quantize")
-
-    with subprocess.Popen(
-        args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
-    ) as process:
-        client = Client(f"http://localhost:{port}")
-
-        for _ in range(60):
-            launcher_output = process.stdout.read1().decode("utf-8")
-            print(launcher_output)
-
-            exit_code = process.poll()
-            if exit_code is not None:
-                launcher_error = process.stderr.read1().decode("utf-8")
-                print(launcher_error)
-                raise RuntimeError(
-                    f"text-generation-launcher terminated with exit code {exit_code}"
-                )
-
-            try:
-                client.generate("test", max_new_tokens=1)
-                break
-            except ConnectionError:
-                time.sleep(1)
-
-        yield client
-
-        process.stdout.close()
-        process.stderr.close()
-        process.terminate()
-
-
-@pytest.fixture(scope="session")
-def bloom_560m():
-    with launcher("bigscience/bloom-560m") as client:
-        yield client
-
-
-@pytest.fixture(scope="session")
-def bloom_560m_multi():
-    with launcher("bigscience/bloom-560m", num_shard=2) as client:
-        yield client
+import subprocess
+import time
+import contextlib
+
+import pytest
+import asyncio
+import os
+
+from typing import Optional, List
+from aiohttp import ClientConnectorError
+
+from text_generation import AsyncClient
+from text_generation.types import Response
+
+DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
+
+
+@pytest.fixture(scope="module")
+def event_loop():
+    loop = asyncio.get_event_loop()
+    yield loop
+    loop.close()
+
+
+@pytest.fixture(scope="module")
+def launcher(event_loop):
+    @contextlib.asynccontextmanager
+    async def local_launcher_inner(model_id: str, num_shard: Optional[int] = None, quantize: bool = False):
+        port = 9999
+        master_port = 19999
+
+        shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server"
+
+        args = [
+            "text-generation-launcher",
+            "--model-id",
+            model_id,
+            "--port",
+            str(port),
+            "--master-port",
+            str(master_port),
+            "--shard-uds-path",
+            shard_uds_path,
+        ]
+
+        if num_shard is not None:
+            args.extend(["--num-shard", str(num_shard)])
+        if quantize:
+            args.append("--quantize")
+
+        with subprocess.Popen(
+            args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+        ) as process:
+            client = AsyncClient(f"http://localhost:{port}")
+
+            healthy = False
+
+            for _ in range(60):
+                launcher_output = process.stdout.read1().decode("utf-8")
+                print(launcher_output)
+
+                exit_code = process.poll()
+                if exit_code is not None:
+                    launcher_error = process.stderr.read1().decode("utf-8")
+                    print(launcher_error)
+                    raise RuntimeError(
+                        f"text-generation-launcher terminated with exit code {exit_code}"
+                    )
+
+                try:
+                    await client.generate("test", max_new_tokens=1)
+                    healthy = True
+                    break
+                except ClientConnectorError:
+                    time.sleep(1)
+
+            if healthy:
+                yield client
+
+            process.terminate()
+
+            for _ in range(60):
+                exit_code = process.wait(1)
+                if exit_code is not None:
+                    break
+
+            launcher_output = process.stdout.read1().decode("utf-8")
+            print(launcher_output)
+
+            process.stdout.close()
+            process.stderr.close()
+
+            if not healthy:
+                raise RuntimeError(f"unable to start model {model_id} with command: {' '.join(args)}")
+
+    return local_launcher_inner
+
+
+@pytest.fixture(scope="module")
+def generate_load():
+    async def generate_load_inner(client: AsyncClient, prompt: str, max_new_tokens: int, n: int) -> List[Response]:
+        futures = [client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)]
+
+        results = await asyncio.gather(*futures)
+        return results
+
+    return generate_load_inner
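Two conventions in the rewritten fixtures are easy to miss. Redefining `event_loop` as a module-scoped fixture appears to follow pytest-asyncio's convention for widening the loop scope so the module-scoped async fixtures in the test files can share one loop, and having `launcher` return an `asynccontextmanager` factory (rather than an already-started server) lets each test module choose its own model and shard count. A minimal self-contained sketch of that factory pattern, assuming a `FakeClient` stand-in that is illustrative and not part of the repository:

import asyncio
import contextlib

import pytest


class FakeClient:
    """Illustrative stand-in for text_generation.AsyncClient."""

    def __init__(self, model_id: str):
        self.model_id = model_id

    async def generate(self, prompt: str, max_new_tokens: int = 20) -> str:
        await asyncio.sleep(0)  # pretend to do I/O
        return f"{self.model_id}: {prompt} (+{max_new_tokens} tokens)"


@pytest.fixture(scope="module")
def launcher():
    # The fixture returns a factory instead of a running server, so each
    # test module picks its own arguments via `async with launcher(...)`.
    @contextlib.asynccontextmanager
    async def launcher_inner(model_id: str):
        client = FakeClient(model_id)  # real code: spawn the launcher subprocess
        try:
            yield client
        finally:
            pass  # real code: terminate the subprocess and close its pipes

    return launcher_inner

One caveat in the shutdown path of the diff above: `Popen.wait(1)` raises `subprocess.TimeoutExpired` when the timeout elapses rather than returning `None`, so the `if exit_code is not None` guard never observes a timeout; a hardened version would wrap the call in `try/except subprocess.TimeoutExpired`.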
@@ -1,6 +1,25 @@
-def test_bloom_560m(bloom_560m, snapshot):
-    response = bloom_560m.generate("Test request")
-    # response_multi = bloom_560m_multi.generate("Test request")
-
-    # assert response == response_multi == snapshot
-    assert response == snapshot
+import pytest
+
+
+@pytest.fixture(scope="module")
+async def bloom_560m(launcher):
+    async with launcher("bigscience/bloom-560m") as client:
+        yield client
+
+
+@pytest.mark.asyncio
+async def test_bloom_560m(bloom_560m, snapshot):
+    response = await bloom_560m.generate("Test request", max_new_tokens=10)
+
+    assert response.details.generated_tokens == 10
+    assert response == snapshot
+
+
+@pytest.mark.asyncio
+async def test_bloom_560m_load(bloom_560m, generate_load, snapshot):
+    responses = await generate_load(bloom_560m, "Test request", max_new_tokens=10, n=4)
+
+    assert len(responses) == 4
+
+    assert responses == snapshot
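`test_bloom_560m_load` depends on `generate_load` issuing its `n` requests concurrently, not sequentially: the list comprehension creates all the coroutines first, then `asyncio.gather` awaits them together. A runnable sketch of the same fan-out, with a sleeping coroutine standing in for `AsyncClient.generate` (all names here are illustrative):

import asyncio
import time
from typing import List


async def fake_generate(prompt: str, max_new_tokens: int) -> str:
    # Stand-in for AsyncClient.generate: simulate ~100 ms of server latency.
    await asyncio.sleep(0.1)
    return f"{prompt} (+{max_new_tokens} tokens)"


async def generate_load(prompt: str, max_new_tokens: int, n: int) -> List[str]:
    # All n coroutines are scheduled before any is awaited, so they overlap:
    # total wall time is ~one request, not n requests.
    futures = [fake_generate(prompt, max_new_tokens) for _ in range(n)]
    return await asyncio.gather(*futures)


start = time.perf_counter()
results = asyncio.run(generate_load("Test request", 10, 4))
print(len(results), f"{time.perf_counter() - start:.2f}s")  # 4, ~0.10s

With n=4 and ~100 ms of simulated latency per request, the batch finishes in roughly 100 ms rather than 400 ms, which is what makes this usable as a light concurrency test against the server.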
integration-tests/models/test_bloom_560m_sharded.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+async def bloom_560m_sharded(launcher):
+    async with launcher("bigscience/bloom-560m", num_shard=2) as client:
+        yield client
+
+
+@pytest.mark.asyncio
+async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot):
+    response = await bloom_560m_sharded.generate("Test request", max_new_tokens=10)
+
+    assert response.details.generated_tokens == 10
+    assert response == snapshot
+
+
+@pytest.mark.asyncio
+async def test_bloom_560m_sharded_load(bloom_560m_sharded, generate_load, snapshot):
+    responses = await generate_load(bloom_560m_sharded, "Test request", max_new_tokens=10, n=4)
+
+    assert len(responses) == 4
+
+    assert responses == snapshot
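Both test modules assume the `launcher` fixture yields a client only once the server actually answers. The health check in the fixtures interleaves two conditions: fail fast if the subprocess has already exited, otherwise retry the probe request once per second for up to 60 seconds. A condensed synchronous sketch of that loop, where `probe` is a hypothetical callable standing in for the first `generate` call:

import subprocess
import time


def wait_until_healthy(process: subprocess.Popen, probe, retries: int = 60) -> bool:
    # Returns True once `probe()` succeeds, raises if the process died,
    # and returns False if the server never came up within `retries` seconds.
    for _ in range(retries):
        exit_code = process.poll()  # non-blocking: None while still running
        if exit_code is not None:
            raise RuntimeError(f"launcher terminated with exit code {exit_code}")
        try:
            probe()
            return True
        except ConnectionError:
            time.sleep(1)
    return False

Returning `False` on timeout instead of raising mirrors the `healthy` flag in the diff above, which defers the "unable to start model" error until after teardown so the launcher's output can still be drained and printed.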