mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Having less logs in case of failure for checking CI more easily.
This commit is contained in:
parent
a7448661f7
commit
230b25165d
@ -867,7 +867,7 @@ class AsyncClient:
|
|||||||
async with ClientSession(
|
async with ClientSession(
|
||||||
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
||||||
) as session:
|
) as session:
|
||||||
async with session.post(self.base_url, json=request.dict()) as resp:
|
async with session.post(self.base_url, json=request.model_dump()) as resp:
|
||||||
payload = await resp.json()
|
payload = await resp.json()
|
||||||
|
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# ruff: noqa: E402
|
# ruff: noqa: E402
|
||||||
|
from _pytest.fixtures import SubRequest
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
||||||
@ -10,13 +11,13 @@ class SessionTimeoutFix(requests.Session):
|
|||||||
|
|
||||||
requests.sessions.Session = SessionTimeoutFix
|
requests.sessions.Session = SessionTimeoutFix
|
||||||
|
|
||||||
|
import warnings
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import shutil
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
@ -72,6 +73,16 @@ def pytest_collection_modifyitems(config, items):
|
|||||||
item.add_marker(skip_release)
|
item.add_marker(skip_release)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def container_log(request: SubRequest):
|
||||||
|
error_log = request.getfixturevalue("error_log")
|
||||||
|
assert error_log is not None
|
||||||
|
yield
|
||||||
|
if request.session.testsfailed:
|
||||||
|
error_log.seek(0)
|
||||||
|
print(error_log.read(), file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
class ResponseComparator(JSONSnapshotExtension):
|
class ResponseComparator(JSONSnapshotExtension):
|
||||||
rtol = 0.2
|
rtol = 0.2
|
||||||
ignore_logprob = False
|
ignore_logprob = False
|
||||||
@ -278,8 +289,10 @@ class IgnoreLogProbResponseComparator(ResponseComparator):
|
|||||||
|
|
||||||
|
|
||||||
class LauncherHandle:
|
class LauncherHandle:
|
||||||
def __init__(self, port: int):
|
def __init__(self, port: int, error_log):
|
||||||
|
with warnings.catch_warnings(action="ignore"):
|
||||||
self.client = AsyncClient(f"http://localhost:{port}", timeout=30)
|
self.client = AsyncClient(f"http://localhost:{port}", timeout=30)
|
||||||
|
self.error_log = error_log
|
||||||
|
|
||||||
def _inner_health(self):
|
def _inner_health(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -288,6 +301,8 @@ class LauncherHandle:
|
|||||||
assert timeout > 0
|
assert timeout > 0
|
||||||
for _ in range(timeout):
|
for _ in range(timeout):
|
||||||
if not self._inner_health():
|
if not self._inner_health():
|
||||||
|
self.error_log.seek(0)
|
||||||
|
print(self.error_log.read(), file=sys.stderr)
|
||||||
raise RuntimeError("Launcher crashed")
|
raise RuntimeError("Launcher crashed")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -295,12 +310,14 @@ class LauncherHandle:
|
|||||||
return
|
return
|
||||||
except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
|
except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
self.error_log.seek(0)
|
||||||
|
print(self.error_log.read(), file=sys.stderr)
|
||||||
raise RuntimeError("Health check failed")
|
raise RuntimeError("Health check failed")
|
||||||
|
|
||||||
|
|
||||||
class ContainerLauncherHandle(LauncherHandle):
|
class ContainerLauncherHandle(LauncherHandle):
|
||||||
def __init__(self, docker_client, container_name, port: int):
|
def __init__(self, docker_client, container_name, port: int, error_log):
|
||||||
super(ContainerLauncherHandle, self).__init__(port)
|
super().__init__(port, error_log)
|
||||||
self.docker_client = docker_client
|
self.docker_client = docker_client
|
||||||
self.container_name = container_name
|
self.container_name = container_name
|
||||||
|
|
||||||
@ -310,8 +327,8 @@ class ContainerLauncherHandle(LauncherHandle):
|
|||||||
|
|
||||||
|
|
||||||
class ProcessLauncherHandle(LauncherHandle):
|
class ProcessLauncherHandle(LauncherHandle):
|
||||||
def __init__(self, process, port: int):
|
def __init__(self, process, port: int, error_log):
|
||||||
super(ProcessLauncherHandle, self).__init__(port)
|
super().__init__(port, error_log)
|
||||||
self.process = process
|
self.process = process
|
||||||
|
|
||||||
def _inner_health(self) -> bool:
|
def _inner_health(self) -> bool:
|
||||||
@ -334,14 +351,13 @@ def ignore_logprob_response_snapshot(snapshot):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def event_loop():
|
def error_log():
|
||||||
loop = asyncio.get_event_loop()
|
with tempfile.TemporaryFile("w+") as tmp:
|
||||||
yield loop
|
yield tmp
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def launcher(event_loop):
|
async def launcher(error_log):
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def local_launcher(
|
def local_launcher(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
@ -429,23 +445,20 @@ def launcher(event_loop):
|
|||||||
if attention is not None:
|
if attention is not None:
|
||||||
env["ATTENTION"] = attention
|
env["ATTENTION"] = attention
|
||||||
|
|
||||||
with tempfile.TemporaryFile("w+") as tmp:
|
# with tempfile.TemporaryFile("w+") as tmp:
|
||||||
# We'll output stdout/stderr to a temporary file. Using a pipe
|
# We'll output stdout/stderr to a temporary file. Using a pipe
|
||||||
# cause the process to block until stdout is read.
|
# cause the process to block until stdout is read.
|
||||||
with subprocess.Popen(
|
with subprocess.Popen(
|
||||||
args,
|
args,
|
||||||
stdout=tmp,
|
stdout=error_log,
|
||||||
stderr=subprocess.STDOUT,
|
stderr=subprocess.STDOUT,
|
||||||
env=env,
|
env=env,
|
||||||
) as process:
|
) as process:
|
||||||
yield ProcessLauncherHandle(process, port)
|
yield ProcessLauncherHandle(process, port, error_log=error_log)
|
||||||
|
|
||||||
process.terminate()
|
process.terminate()
|
||||||
process.wait(60)
|
process.wait(60)
|
||||||
|
|
||||||
tmp.seek(0)
|
|
||||||
shutil.copyfileobj(tmp, sys.stderr)
|
|
||||||
|
|
||||||
if not use_flash_attention:
|
if not use_flash_attention:
|
||||||
del env["USE_FLASH_ATTENTION"]
|
del env["USE_FLASH_ATTENTION"]
|
||||||
|
|
||||||
@ -578,8 +591,21 @@ def launcher(event_loop):
|
|||||||
shm_size="1G",
|
shm_size="1G",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def pipe():
|
||||||
|
for log in container.logs(stream=True):
|
||||||
|
log = log.decode("utf-8")
|
||||||
|
error_log.write(log)
|
||||||
|
|
||||||
|
# Start looping to pipe the logs
|
||||||
|
import threading
|
||||||
|
|
||||||
|
t = threading.Thread(target=pipe, args=())
|
||||||
|
t.start()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield ContainerLauncherHandle(client, container.name, port)
|
yield ContainerLauncherHandle(
|
||||||
|
client, container.name, port, error_log=error_log
|
||||||
|
)
|
||||||
|
|
||||||
if not use_flash_attention:
|
if not use_flash_attention:
|
||||||
del env["USE_FLASH_ATTENTION"]
|
del env["USE_FLASH_ATTENTION"]
|
||||||
@ -590,9 +616,6 @@ def launcher(event_loop):
|
|||||||
except NotFound:
|
except NotFound:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
container_output = container.logs().decode("utf-8")
|
|
||||||
print(container_output, file=sys.stderr)
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
container.remove()
|
container.remove()
|
||||||
|
Loading…
Reference in New Issue
Block a user