Emit fewer logs on failure so that CI results are easier to check.

This commit is contained in:
Nicolas Patry 2025-02-19 11:56:34 +01:00
parent a7448661f7
commit 230b25165d
No known key found for this signature in database
GPG Key ID: 4242CEF24CB6DBF9
2 changed files with 53 additions and 30 deletions

View File

@ -867,7 +867,7 @@ class AsyncClient:
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
async with session.post(self.base_url, json=request.model_dump()) as resp:
payload = await resp.json()
if resp.status != 200:

View File

@ -1,4 +1,5 @@
# ruff: noqa: E402
from _pytest.fixtures import SubRequest
import requests
@ -10,13 +11,13 @@ class SessionTimeoutFix(requests.Session):
requests.sessions.Session = SessionTimeoutFix
import warnings
import asyncio
import contextlib
import json
import math
import os
import random
import shutil
import subprocess
import sys
import tempfile
@ -72,6 +73,16 @@ def pytest_collection_modifyitems(config, items):
item.add_marker(skip_release)
@pytest.fixture(autouse=True)
def container_log(request: SubRequest):
    """Dump the captured launcher log to stderr when the current test fails.

    Autouse fixture: it resolves the ``error_log`` fixture (the seekable text
    file the launcher output is redirected to), lets the test run, and on
    teardown prints the whole file to stderr only if the session's failure
    counter increased while this test was running.
    """
    error_log = request.getfixturevalue("error_log")
    assert error_log is not None
    # ``session.testsfailed`` is a running total for the whole session, so a
    # plain truthiness check would re-print the log after every later test
    # once anything had failed. Remember the value seen at setup and compare.
    failures_before = request.session.testsfailed
    yield
    if request.session.testsfailed > failures_before:
        error_log.seek(0)
        print(error_log.read(), file=sys.stderr)
class ResponseComparator(JSONSnapshotExtension):
rtol = 0.2
ignore_logprob = False
@ -278,8 +289,10 @@ class IgnoreLogProbResponseComparator(ResponseComparator):
class LauncherHandle:
def __init__(self, port: int):
def __init__(self, port: int, error_log):
with warnings.catch_warnings(action="ignore"):
self.client = AsyncClient(f"http://localhost:{port}", timeout=30)
self.error_log = error_log
def _inner_health(self):
raise NotImplementedError
@ -288,6 +301,8 @@ class LauncherHandle:
assert timeout > 0
for _ in range(timeout):
if not self._inner_health():
self.error_log.seek(0)
print(self.error_log.read(), file=sys.stderr)
raise RuntimeError("Launcher crashed")
try:
@ -295,12 +310,14 @@ class LauncherHandle:
return
except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
time.sleep(1)
self.error_log.seek(0)
print(self.error_log.read(), file=sys.stderr)
raise RuntimeError("Health check failed")
class ContainerLauncherHandle(LauncherHandle):
def __init__(self, docker_client, container_name, port: int):
super(ContainerLauncherHandle, self).__init__(port)
def __init__(self, docker_client, container_name, port: int, error_log):
super().__init__(port, error_log)
self.docker_client = docker_client
self.container_name = container_name
@ -310,8 +327,8 @@ class ContainerLauncherHandle(LauncherHandle):
class ProcessLauncherHandle(LauncherHandle):
def __init__(self, process, port: int):
super(ProcessLauncherHandle, self).__init__(port)
def __init__(self, process, port: int, error_log):
super().__init__(port, error_log)
self.process = process
def _inner_health(self) -> bool:
@ -334,14 +351,13 @@ def ignore_logprob_response_snapshot(snapshot):
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()
def error_log():
with tempfile.TemporaryFile("w+") as tmp:
yield tmp
@pytest.fixture(scope="module")
def launcher(event_loop):
async def launcher(error_log):
@contextlib.contextmanager
def local_launcher(
model_id: str,
@ -429,23 +445,20 @@ def launcher(event_loop):
if attention is not None:
env["ATTENTION"] = attention
with tempfile.TemporaryFile("w+") as tmp:
# with tempfile.TemporaryFile("w+") as tmp:
# We'll output stdout/stderr to a temporary file. Using a pipe
# would cause the process to block until stdout is read.
with subprocess.Popen(
args,
stdout=tmp,
stdout=error_log,
stderr=subprocess.STDOUT,
env=env,
) as process:
yield ProcessLauncherHandle(process, port)
yield ProcessLauncherHandle(process, port, error_log=error_log)
process.terminate()
process.wait(60)
tmp.seek(0)
shutil.copyfileobj(tmp, sys.stderr)
if not use_flash_attention:
del env["USE_FLASH_ATTENTION"]
@ -578,8 +591,21 @@ def launcher(event_loop):
shm_size="1G",
)
def pipe():
for log in container.logs(stream=True):
log = log.decode("utf-8")
error_log.write(log)
# Start looping to pipe the logs
import threading
t = threading.Thread(target=pipe, args=())
t.start()
try:
yield ContainerLauncherHandle(client, container.name, port)
yield ContainerLauncherHandle(
client, container.name, port, error_log=error_log
)
if not use_flash_attention:
del env["USE_FLASH_ATTENTION"]
@ -590,9 +616,6 @@ def launcher(event_loop):
except NotFound:
pass
container_output = container.logs().decode("utf-8")
print(container_output, file=sys.stderr)
finally:
try:
container.remove()