"""Pytest fixtures for running text-generation-inference (TGI) on Gaudi in Docker.

Provides a `launcher` fixture that starts a TGI container for a given model,
a `generate_load` fixture that issues concurrent generate requests, and
launcher handles that poll the server until it is healthy.
"""

import asyncio
import contextlib
import os
import shlex
import socket
import subprocess
import sys
import threading
import time
from tempfile import TemporaryDirectory
from typing import List

import docker
import pytest
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from docker.errors import NotFound
from loguru import logger
from test_model import TEST_CONFIGS
from text_generation import AsyncClient
from text_generation.types import Response

# Use the latest image from the local docker build
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "tgi-gaudi")
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", None)
HF_TOKEN = os.getenv("HF_TOKEN", None)

assert (
    HF_TOKEN is not None
), "HF_TOKEN is not set; some models are gated, so the tests will fail without it"

if DOCKER_VOLUME is None:
    logger.warning(
        "DOCKER_VOLUME is not set; the tests will redownload the models on each run. Consider setting it to speed up testing."
    )

LOG_LEVEL = os.getenv("LOG_LEVEL", "info")

BASE_ENV = {
    "HF_HUB_ENABLE_HF_TRANSFER": "1",
    "LOG_LEVEL": LOG_LEVEL,
    "HF_TOKEN": HF_TOKEN,
}

HABANA_RUN_ARGS = {
    "runtime": "habana",
    "ipc_mode": "host",
    "cap_add": ["sys_nice"],
}

logger.add(
    sys.stderr,
    format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
    level="INFO",
)


def stream_container_logs(container, test_name):
    """Stream container logs in a separate thread."""
    try:
        for log in container.logs(stream=True, follow=True):
            print(
                f"[TGI Server Logs - {test_name}] {log.decode('utf-8')}",
                end="",
                file=sys.stderr,
                flush=True,
            )
    except Exception as e:
        logger.error(f"Error streaming container logs: {str(e)}")


class LauncherHandle:
    def __init__(self, port: int):
        self.client = AsyncClient(f"http://localhost:{port}", timeout=3600)

    def _inner_health(self):
        raise NotImplementedError

    async def health(self, timeout: int = 60):
        assert timeout > 0
        start_time = time.time()
        logger.info(f"Starting health check with timeout of {timeout}s")

        for attempt in range(timeout):
            if not self._inner_health():
                logger.error("Launcher crashed during health check")
                raise RuntimeError("Launcher crashed")

            try:
                await self.client.generate("test")
                elapsed = time.time() - start_time
                logger.info(f"Health check passed after {elapsed:.1f}s")
                return
            except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
                if attempt == timeout - 1:
                    logger.error(f"Health check failed after {timeout}s: {str(e)}")
                    raise RuntimeError(f"Health check failed: {str(e)}")
                if attempt % 10 == 0 and attempt != 0:  # Only log every 10th attempt
                    logger.debug(
                        f"Connection attempt {attempt}/{timeout} failed: {str(e)}"
                    )
                # Non-blocking sleep so the event loop is not stalled between retries
                await asyncio.sleep(1)
            except Exception as e:
                logger.error(f"Unexpected error during health check: {str(e)}")
                # Get full traceback for debugging
                import traceback

                logger.error(f"Full traceback:\n{traceback.format_exc()}")
                raise


class ContainerLauncherHandle(LauncherHandle):
    def __init__(self, docker_client, container_name, port: int):
        super().__init__(port)
        self.docker_client = docker_client
        self.container_name = container_name

    def _inner_health(self) -> bool:
        try:
            container = self.docker_client.containers.get(self.container_name)
            status = container.status
            if status not in ["running", "created"]:
                logger.warning(f"Container status is {status}")
                # Get container logs for debugging
                logs = container.logs().decode("utf-8")
                logger.debug(f"Container logs:\n{logs}")
            return status in ["running", "created"]
        except Exception as e:
            logger.error(f"Error checking container health: {str(e)}")
            return False


class ProcessLauncherHandle(LauncherHandle):
    def __init__(self, process, port: int):
        super().__init__(port)
        self.process = process

    def _inner_health(self) -> bool:
        return self.process.poll() is None


@pytest.fixture(scope="module")
def data_volume():
    tmpdir = TemporaryDirectory()
    yield tmpdir.name
    try:
        # Clean up the temporary directory using sudo, as it contains root-owned files created by the container
        subprocess.run(shlex.split(f"sudo rm -rf {tmpdir.name}"), check=True)
    except subprocess.CalledProcessError as e:
        logger.error(f"Error cleaning up temporary directory: {str(e)}")


@pytest.fixture(scope="module")
def launcher(data_volume):
    @contextlib.contextmanager
    def docker_launcher(
        model_id: str,
        test_name: str,
    ):
        logger.info(
            f"Starting docker launcher for model {model_id} and test {test_name}"
        )

        # Get a random available port
        def get_free_port():
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(("", 0))
                s.listen(1)
                port = s.getsockname()[1]
            return port

        port = get_free_port()
        logger.debug(f"Using port {port}")

        client = docker.from_env()

        container_name = f"tgi-gaudi-test-{test_name.replace('/', '-')}"

        try:
            container = client.containers.get(container_name)
            logger.info(
                f"Stopping existing container {container_name} for test {test_name}"
            )
            container.stop()
            container.wait()
        except NotFound:
            pass
        except Exception as e:
            logger.error(f"Error handling existing container: {str(e)}")

        model_name = next(
            name for name, cfg in TEST_CONFIGS.items() if cfg["model_id"] == model_id
        )

        tgi_args = TEST_CONFIGS[model_name]["args"].copy()

        env = BASE_ENV.copy()

        # Add model_id to env
        env["MODEL_ID"] = model_id

        # Add any extra env config defined for this model in TEST_CONFIGS
        if "env_config" in TEST_CONFIGS[model_name]:
            env.update(TEST_CONFIGS[model_name]["env_config"].copy())

        # Only mount the data volume when DOCKER_VOLUME is set; otherwise the models are redownloaded
        volumes = [f"{DOCKER_VOLUME}:/data"] if DOCKER_VOLUME is not None else []
        logger.debug(f"Using volume {volumes}")

        try:
            logger.info(f"Creating container with name {container_name}")

            # Start the TGI server container (the docker-py equivalent of `docker run`)
            container = client.containers.run(
                DOCKER_IMAGE,
                command=tgi_args,
                name=container_name,
                environment=env,
                detach=True,
                volumes=volumes,
                ports={"80/tcp": port},
                **HABANA_RUN_ARGS,
            )

            logger.info(f"Container {container_name} started successfully")

            # Start log streaming in a background thread
            log_thread = threading.Thread(
                target=stream_container_logs,
                args=(container, test_name),
                daemon=True,  # Ensures the thread is killed when the main program exits
            )
            log_thread.start()

            # Add a small delay to allow the container to initialize
            time.sleep(2)

            # Check container status after creation
            status = container.status
            logger.debug(f"Initial container status: {status}")
            if status not in ["running", "created"]:
                logs = container.logs().decode("utf-8")
                logger.error(f"Container failed to start properly. Logs:\n{logs}")

            yield ContainerLauncherHandle(client, container.name, port)

        except Exception as e:
            logger.error(f"Error starting container: {str(e)}")
            # Get full traceback for debugging
            import traceback

            logger.error(f"Full traceback:\n{traceback.format_exc()}")
            raise
        finally:
            try:
                container = client.containers.get(container_name)
                logger.info(f"Stopping container {container_name}")
                container.stop()
                container.wait()

                container_output = container.logs().decode("utf-8")
                print(container_output, file=sys.stderr)

                container.remove()
                logger.info(f"Container {container_name} removed successfully")
            except NotFound:
                pass
            except Exception as e:
                logger.warning(f"Error cleaning up container: {str(e)}")

    return docker_launcher


@pytest.fixture(scope="module")
def generate_load():
    async def generate_load_inner(
        client: AsyncClient, prompt: str, max_new_tokens: int, n: int
    ) -> List[Response]:
        try:
            futures = [
                client.generate(
                    prompt,
                    max_new_tokens=max_new_tokens,
                    decoder_input_details=True,
                )
                for _ in range(n)
            ]
            return await asyncio.gather(*futures)
        except Exception as e:
            logger.error(f"Error generating load: {str(e)}")
            raise

    return generate_load_inner