diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 45301b63..0b60d93a 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -867,7 +867,7 @@ class AsyncClient: async with ClientSession( headers=self.headers, cookies=self.cookies, timeout=self.timeout ) as session: - async with session.post(self.base_url, json=request.dict()) as resp: + async with session.post(self.base_url, json=request.model_dump()) as resp: payload = await resp.json() if resp.status != 200: diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 6f8aa715..2d3ae8a2 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -1,4 +1,5 @@ # ruff: noqa: E402 +from _pytest.fixtures import SubRequest import requests @@ -10,13 +11,13 @@ class SessionTimeoutFix(requests.Session): requests.sessions.Session = SessionTimeoutFix +import warnings import asyncio import contextlib import json import math import os import random -import shutil import subprocess import sys import tempfile @@ -72,6 +73,16 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_release) +@pytest.fixture(autouse=True) +def container_log(request: SubRequest): + error_log = request.getfixturevalue("error_log") + assert error_log is not None + yield + if request.session.testsfailed: + error_log.seek(0) + print(error_log.read(), file=sys.stderr) + + class ResponseComparator(JSONSnapshotExtension): rtol = 0.2 ignore_logprob = False @@ -278,8 +289,10 @@ class IgnoreLogProbResponseComparator(ResponseComparator): class LauncherHandle: - def __init__(self, port: int): - self.client = AsyncClient(f"http://localhost:{port}", timeout=30) + def __init__(self, port: int, error_log): + with warnings.catch_warnings(action="ignore"): + self.client = AsyncClient(f"http://localhost:{port}", timeout=30) + self.error_log = error_log def _inner_health(self): raise NotImplementedError @@ -288,6 +301,8 @@ class LauncherHandle: assert timeout > 0 for _ in range(timeout): if not self._inner_health(): + self.error_log.seek(0) + print(self.error_log.read(), file=sys.stderr) raise RuntimeError("Launcher crashed") try: @@ -295,12 +310,14 @@ class LauncherHandle: return except (ClientConnectorError, ClientOSError, ServerDisconnectedError): time.sleep(1) + self.error_log.seek(0) + print(self.error_log.read(), file=sys.stderr) raise RuntimeError("Health check failed") class ContainerLauncherHandle(LauncherHandle): - def __init__(self, docker_client, container_name, port: int): - super(ContainerLauncherHandle, self).__init__(port) + def __init__(self, docker_client, container_name, port: int, error_log): + super().__init__(port, error_log) self.docker_client = docker_client self.container_name = container_name @@ -310,8 +327,8 @@ class ContainerLauncherHandle(LauncherHandle): class ProcessLauncherHandle(LauncherHandle): - def __init__(self, process, port: int): - super(ProcessLauncherHandle, self).__init__(port) + def __init__(self, process, port: int, error_log): + super().__init__(port, error_log) self.process = process def _inner_health(self) -> bool: @@ -334,14 +351,13 @@ def ignore_logprob_response_snapshot(snapshot): @pytest.fixture(scope="module") -def event_loop(): - loop = asyncio.get_event_loop() - yield loop - loop.close() +def error_log(): + with tempfile.TemporaryFile("w+") as tmp: + yield tmp @pytest.fixture(scope="module") -def launcher(event_loop): +async def launcher(error_log): @contextlib.contextmanager def local_launcher( model_id: str, @@ -429,22 +445,19 @@ def launcher(event_loop): if attention is not None: env["ATTENTION"] = attention - with tempfile.TemporaryFile("w+") as tmp: + # with tempfile.TemporaryFile("w+") as tmp: # We'll output stdout/stderr to a temporary file. Using a pipe # cause the process to block until stdout is read. - with subprocess.Popen( - args, - stdout=tmp, - stderr=subprocess.STDOUT, - env=env, - ) as process: - yield ProcessLauncherHandle(process, port) + with subprocess.Popen( + args, + stdout=error_log, + stderr=subprocess.STDOUT, + env=env, + ) as process: + yield ProcessLauncherHandle(process, port, error_log=error_log) - process.terminate() - process.wait(60) - - tmp.seek(0) - shutil.copyfileobj(tmp, sys.stderr) + process.terminate() + process.wait(60) if not use_flash_attention: del env["USE_FLASH_ATTENTION"] @@ -578,8 +591,21 @@ def launcher(event_loop): shm_size="1G", ) + def pipe(): + for log in container.logs(stream=True): + log = log.decode("utf-8") + error_log.write(log) + + # Start looping to pipe the logs + import threading + + t = threading.Thread(target=pipe, args=()) + t.start() + try: - yield ContainerLauncherHandle(client, container.name, port) + yield ContainerLauncherHandle( + client, container.name, port, error_log=error_log + ) if not use_flash_attention: del env["USE_FLASH_ATTENTION"] @@ -590,9 +616,6 @@ def launcher(event_loop): except NotFound: pass - container_output = container.logs().decode("utf-8") - print(container_output, file=sys.stderr) - finally: try: container.remove()