mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
wip
This commit is contained in:
parent
f58f0a0364
commit
338d30b43b
74
integration-tests/conftest.py
Normal file
74
integration-tests/conftest.py
Normal file
@ -0,0 +1,74 @@
|
||||
import pytest
|
||||
import subprocess
|
||||
import time
|
||||
import contextlib
|
||||
|
||||
from text_generation import Client
|
||||
from typing import Optional
|
||||
from requests import ConnectionError
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def launcher(model_id: str, num_shard: Optional[int] = None, quantize: bool = False):
|
||||
port = 9999
|
||||
master_port = 19999
|
||||
|
||||
shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server"
|
||||
|
||||
args = [
|
||||
"text-generation-launcher",
|
||||
"--model-id",
|
||||
model_id,
|
||||
"--port",
|
||||
str(port),
|
||||
"--master-port",
|
||||
str(master_port),
|
||||
"--shard-uds-path",
|
||||
shard_uds_path,
|
||||
]
|
||||
|
||||
if num_shard is not None:
|
||||
args.extend(["--num-shard", num_shard])
|
||||
if quantize:
|
||||
args.append("--quantize")
|
||||
|
||||
with subprocess.Popen(
|
||||
args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
) as process:
|
||||
client = Client(f"http://localhost:{port}")
|
||||
|
||||
for _ in range(60):
|
||||
launcher_output = process.stdout.read1().decode("utf-8")
|
||||
print(launcher_output)
|
||||
|
||||
exit_code = process.poll()
|
||||
if exit_code is not None:
|
||||
launcher_error = process.stderr.read1().decode("utf-8")
|
||||
print(launcher_error)
|
||||
raise RuntimeError(
|
||||
f"text-generation-launcher terminated with exit code {exit_code}"
|
||||
)
|
||||
|
||||
try:
|
||||
client.generate("test", max_new_tokens=1)
|
||||
break
|
||||
except ConnectionError:
|
||||
time.sleep(1)
|
||||
|
||||
yield client
|
||||
|
||||
process.stdout.close()
|
||||
process.stderr.close()
|
||||
process.terminate()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def bloom_560m():
|
||||
with launcher("bigscience/bloom-560m") as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def bloom_560m_multi():
|
||||
with launcher("bigscience/bloom-560m", num_shard=2) as client:
|
||||
yield client
|
6
integration-tests/models/test_bloom_560m.py
Normal file
6
integration-tests/models/test_bloom_560m.py
Normal file
@ -0,0 +1,6 @@
|
||||
def test_bloom_560m(bloom_560m, snapshot):
|
||||
response = bloom_560m.generate("Test request")
|
||||
# response_multi = bloom_560m_multi.generate("Test request")
|
||||
# assert response == response_multi == snapshot
|
||||
assert response == snapshot
|
||||
|
3
integration-tests/requirements.txt
Normal file
3
integration-tests/requirements.txt
Normal file
@ -0,0 +1,3 @@
|
||||
syrupy==0.4.2
|
||||
text-generation==0.5.1
|
||||
pytest==7.3.1
|
Loading…
Reference in New Issue
Block a user