Having less logs in case of failure for checking CI more easily.

This commit is contained in:
Nicolas Patry 2025-02-19 11:56:34 +01:00
parent a7448661f7
commit 230b25165d
No known key found for this signature in database
GPG Key ID: 4242CEF24CB6DBF9
2 changed files with 53 additions and 30 deletions

View File

@ -867,7 +867,7 @@ class AsyncClient:
async with ClientSession( async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session: ) as session:
async with session.post(self.base_url, json=request.dict()) as resp: async with session.post(self.base_url, json=request.model_dump()) as resp:
payload = await resp.json() payload = await resp.json()
if resp.status != 200: if resp.status != 200:

View File

@ -1,4 +1,5 @@
# ruff: noqa: E402 # ruff: noqa: E402
from _pytest.fixtures import SubRequest
import requests import requests
@ -10,13 +11,13 @@ class SessionTimeoutFix(requests.Session):
requests.sessions.Session = SessionTimeoutFix requests.sessions.Session = SessionTimeoutFix
import warnings
import asyncio import asyncio
import contextlib import contextlib
import json import json
import math import math
import os import os
import random import random
import shutil
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
@ -72,6 +73,16 @@ def pytest_collection_modifyitems(config, items):
item.add_marker(skip_release) item.add_marker(skip_release)
@pytest.fixture(autouse=True)
def container_log(request: SubRequest):
error_log = request.getfixturevalue("error_log")
assert error_log is not None
yield
if request.session.testsfailed:
error_log.seek(0)
print(error_log.read(), file=sys.stderr)
class ResponseComparator(JSONSnapshotExtension): class ResponseComparator(JSONSnapshotExtension):
rtol = 0.2 rtol = 0.2
ignore_logprob = False ignore_logprob = False
@ -278,8 +289,10 @@ class IgnoreLogProbResponseComparator(ResponseComparator):
class LauncherHandle: class LauncherHandle:
def __init__(self, port: int): def __init__(self, port: int, error_log):
with warnings.catch_warnings(action="ignore"):
self.client = AsyncClient(f"http://localhost:{port}", timeout=30) self.client = AsyncClient(f"http://localhost:{port}", timeout=30)
self.error_log = error_log
def _inner_health(self): def _inner_health(self):
raise NotImplementedError raise NotImplementedError
@ -288,6 +301,8 @@ class LauncherHandle:
assert timeout > 0 assert timeout > 0
for _ in range(timeout): for _ in range(timeout):
if not self._inner_health(): if not self._inner_health():
self.error_log.seek(0)
print(self.error_log.read(), file=sys.stderr)
raise RuntimeError("Launcher crashed") raise RuntimeError("Launcher crashed")
try: try:
@ -295,12 +310,14 @@ class LauncherHandle:
return return
except (ClientConnectorError, ClientOSError, ServerDisconnectedError): except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
time.sleep(1) time.sleep(1)
self.error_log.seek(0)
print(self.error_log.read(), file=sys.stderr)
raise RuntimeError("Health check failed") raise RuntimeError("Health check failed")
class ContainerLauncherHandle(LauncherHandle): class ContainerLauncherHandle(LauncherHandle):
def __init__(self, docker_client, container_name, port: int): def __init__(self, docker_client, container_name, port: int, error_log):
super(ContainerLauncherHandle, self).__init__(port) super().__init__(port, error_log)
self.docker_client = docker_client self.docker_client = docker_client
self.container_name = container_name self.container_name = container_name
@ -310,8 +327,8 @@ class ContainerLauncherHandle(LauncherHandle):
class ProcessLauncherHandle(LauncherHandle): class ProcessLauncherHandle(LauncherHandle):
def __init__(self, process, port: int): def __init__(self, process, port: int, error_log):
super(ProcessLauncherHandle, self).__init__(port) super().__init__(port, error_log)
self.process = process self.process = process
def _inner_health(self) -> bool: def _inner_health(self) -> bool:
@ -334,14 +351,13 @@ def ignore_logprob_response_snapshot(snapshot):
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def event_loop(): def error_log():
loop = asyncio.get_event_loop() with tempfile.TemporaryFile("w+") as tmp:
yield loop yield tmp
loop.close()
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def launcher(event_loop): async def launcher(error_log):
@contextlib.contextmanager @contextlib.contextmanager
def local_launcher( def local_launcher(
model_id: str, model_id: str,
@ -429,23 +445,20 @@ def launcher(event_loop):
if attention is not None: if attention is not None:
env["ATTENTION"] = attention env["ATTENTION"] = attention
with tempfile.TemporaryFile("w+") as tmp: # with tempfile.TemporaryFile("w+") as tmp:
# We'll output stdout/stderr to a temporary file. Using a pipe # We'll output stdout/stderr to a temporary file. Using a pipe
# cause the process to block until stdout is read. # cause the process to block until stdout is read.
with subprocess.Popen( with subprocess.Popen(
args, args,
stdout=tmp, stdout=error_log,
stderr=subprocess.STDOUT, stderr=subprocess.STDOUT,
env=env, env=env,
) as process: ) as process:
yield ProcessLauncherHandle(process, port) yield ProcessLauncherHandle(process, port, error_log=error_log)
process.terminate() process.terminate()
process.wait(60) process.wait(60)
tmp.seek(0)
shutil.copyfileobj(tmp, sys.stderr)
if not use_flash_attention: if not use_flash_attention:
del env["USE_FLASH_ATTENTION"] del env["USE_FLASH_ATTENTION"]
@ -578,8 +591,21 @@ def launcher(event_loop):
shm_size="1G", shm_size="1G",
) )
def pipe():
for log in container.logs(stream=True):
log = log.decode("utf-8")
error_log.write(log)
# Start looping to pipe the logs
import threading
t = threading.Thread(target=pipe, args=())
t.start()
try: try:
yield ContainerLauncherHandle(client, container.name, port) yield ContainerLauncherHandle(
client, container.name, port, error_log=error_log
)
if not use_flash_attention: if not use_flash_attention:
del env["USE_FLASH_ATTENTION"] del env["USE_FLASH_ATTENTION"]
@ -590,9 +616,6 @@ def launcher(event_loop):
except NotFound: except NotFound:
pass pass
container_output = container.logs().decode("utf-8")
print(container_output, file=sys.stderr)
finally: finally:
try: try:
container.remove() container.remove()