mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Having fewer logs in case of failure, to make checking CI easier.
This commit is contained in:
parent
a7448661f7
commit
230b25165d
@ -867,7 +867,7 @@ class AsyncClient:
|
||||
async with ClientSession(
|
||||
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
||||
) as session:
|
||||
async with session.post(self.base_url, json=request.dict()) as resp:
|
||||
async with session.post(self.base_url, json=request.model_dump()) as resp:
|
||||
payload = await resp.json()
|
||||
|
||||
if resp.status != 200:
|
||||
|
@ -1,4 +1,5 @@
|
||||
# ruff: noqa: E402
|
||||
from _pytest.fixtures import SubRequest
|
||||
import requests
|
||||
|
||||
|
||||
@ -10,13 +11,13 @@ class SessionTimeoutFix(requests.Session):
|
||||
|
||||
requests.sessions.Session = SessionTimeoutFix
|
||||
|
||||
import warnings
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
@ -72,6 +73,16 @@ def pytest_collection_modifyitems(config, items):
|
||||
item.add_marker(skip_release)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def container_log(request: SubRequest):
|
||||
error_log = request.getfixturevalue("error_log")
|
||||
assert error_log is not None
|
||||
yield
|
||||
if request.session.testsfailed:
|
||||
error_log.seek(0)
|
||||
print(error_log.read(), file=sys.stderr)
|
||||
|
||||
|
||||
class ResponseComparator(JSONSnapshotExtension):
|
||||
rtol = 0.2
|
||||
ignore_logprob = False
|
||||
@ -278,8 +289,10 @@ class IgnoreLogProbResponseComparator(ResponseComparator):
|
||||
|
||||
|
||||
class LauncherHandle:
|
||||
def __init__(self, port: int):
|
||||
def __init__(self, port: int, error_log):
|
||||
with warnings.catch_warnings(action="ignore"):
|
||||
self.client = AsyncClient(f"http://localhost:{port}", timeout=30)
|
||||
self.error_log = error_log
|
||||
|
||||
def _inner_health(self):
|
||||
raise NotImplementedError
|
||||
@ -288,6 +301,8 @@ class LauncherHandle:
|
||||
assert timeout > 0
|
||||
for _ in range(timeout):
|
||||
if not self._inner_health():
|
||||
self.error_log.seek(0)
|
||||
print(self.error_log.read(), file=sys.stderr)
|
||||
raise RuntimeError("Launcher crashed")
|
||||
|
||||
try:
|
||||
@ -295,12 +310,14 @@ class LauncherHandle:
|
||||
return
|
||||
except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
|
||||
time.sleep(1)
|
||||
self.error_log.seek(0)
|
||||
print(self.error_log.read(), file=sys.stderr)
|
||||
raise RuntimeError("Health check failed")
|
||||
|
||||
|
||||
class ContainerLauncherHandle(LauncherHandle):
|
||||
def __init__(self, docker_client, container_name, port: int):
|
||||
super(ContainerLauncherHandle, self).__init__(port)
|
||||
def __init__(self, docker_client, container_name, port: int, error_log):
|
||||
super().__init__(port, error_log)
|
||||
self.docker_client = docker_client
|
||||
self.container_name = container_name
|
||||
|
||||
@ -310,8 +327,8 @@ class ContainerLauncherHandle(LauncherHandle):
|
||||
|
||||
|
||||
class ProcessLauncherHandle(LauncherHandle):
|
||||
def __init__(self, process, port: int):
|
||||
super(ProcessLauncherHandle, self).__init__(port)
|
||||
def __init__(self, process, port: int, error_log):
|
||||
super().__init__(port, error_log)
|
||||
self.process = process
|
||||
|
||||
def _inner_health(self) -> bool:
|
||||
@ -334,14 +351,13 @@ def ignore_logprob_response_snapshot(snapshot):
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
loop = asyncio.get_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
def error_log():
|
||||
with tempfile.TemporaryFile("w+") as tmp:
|
||||
yield tmp
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def launcher(event_loop):
|
||||
async def launcher(error_log):
|
||||
@contextlib.contextmanager
|
||||
def local_launcher(
|
||||
model_id: str,
|
||||
@ -429,23 +445,20 @@ def launcher(event_loop):
|
||||
if attention is not None:
|
||||
env["ATTENTION"] = attention
|
||||
|
||||
with tempfile.TemporaryFile("w+") as tmp:
|
||||
# with tempfile.TemporaryFile("w+") as tmp:
|
||||
# We'll output stdout/stderr to a temporary file. Using a pipe
|
||||
# causes the process to block until stdout is read.
|
||||
with subprocess.Popen(
|
||||
args,
|
||||
stdout=tmp,
|
||||
stdout=error_log,
|
||||
stderr=subprocess.STDOUT,
|
||||
env=env,
|
||||
) as process:
|
||||
yield ProcessLauncherHandle(process, port)
|
||||
yield ProcessLauncherHandle(process, port, error_log=error_log)
|
||||
|
||||
process.terminate()
|
||||
process.wait(60)
|
||||
|
||||
tmp.seek(0)
|
||||
shutil.copyfileobj(tmp, sys.stderr)
|
||||
|
||||
if not use_flash_attention:
|
||||
del env["USE_FLASH_ATTENTION"]
|
||||
|
||||
@ -578,8 +591,21 @@ def launcher(event_loop):
|
||||
shm_size="1G",
|
||||
)
|
||||
|
||||
def pipe():
|
||||
for log in container.logs(stream=True):
|
||||
log = log.decode("utf-8")
|
||||
error_log.write(log)
|
||||
|
||||
# Start looping to pipe the logs
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=pipe, args=())
|
||||
t.start()
|
||||
|
||||
try:
|
||||
yield ContainerLauncherHandle(client, container.name, port)
|
||||
yield ContainerLauncherHandle(
|
||||
client, container.name, port, error_log=error_log
|
||||
)
|
||||
|
||||
if not use_flash_attention:
|
||||
del env["USE_FLASH_ATTENTION"]
|
||||
@ -590,9 +616,6 @@ def launcher(event_loop):
|
||||
except NotFound:
|
||||
pass
|
||||
|
||||
container_output = container.logs().decode("utf-8")
|
||||
print(container_output, file=sys.stderr)
|
||||
|
||||
finally:
|
||||
try:
|
||||
container.remove()
|
||||
|
Loading…
Reference in New Issue
Block a user