mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00
* feat(gaudi): add integration test * feat(test): add more models to integration tests * remove debug comments * fix typos
293 lines
9.7 KiB
Python
293 lines
9.7 KiB
Python
import asyncio
|
|
import contextlib
|
|
import os
|
|
import shlex
|
|
import subprocess
|
|
import sys
|
|
import threading
|
|
import time
|
|
from tempfile import TemporaryDirectory
|
|
from typing import List
|
|
import socket
|
|
|
|
import docker
|
|
import pytest
|
|
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
|
|
from docker.errors import NotFound
|
|
from loguru import logger
|
|
from test_model import TEST_CONFIGS
|
|
from text_generation import AsyncClient
|
|
from text_generation.types import Response
|
|
|
|
# Use the latest image from the local docker build
|
|
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "tgi-gaudi")
|
|
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", None)
|
|
HF_TOKEN = os.getenv("HF_TOKEN", None)
|
|
|
|
assert (
|
|
HF_TOKEN is not None
|
|
), "HF_TOKEN is not set, please set it as some models are gated and thus the test will fail without it"
|
|
|
|
if DOCKER_VOLUME is None:
|
|
logger.warning(
|
|
"DOCKER_VOLUME is not set, this will lead to the tests redownloading the models on each run, consider setting it to speed up testing"
|
|
)
|
|
|
|
LOG_LEVEL = os.getenv("LOG_LEVEL", "info")
|
|
|
|
BASE_ENV = {
|
|
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
|
"LOG_LEVEL": LOG_LEVEL,
|
|
"HF_TOKEN": os.getenv("HF_TOKEN", None),
|
|
}
|
|
|
|
|
|
HABANA_RUN_ARGS = {
|
|
"runtime": "habana",
|
|
"ipc_mode": "host",
|
|
"cap_add": ["sys_nice"],
|
|
}
|
|
|
|
logger.add(
|
|
sys.stderr,
|
|
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
|
|
level="INFO",
|
|
)
|
|
|
|
|
|
def stream_container_logs(container, test_name):
|
|
"""Stream container logs in a separate thread."""
|
|
try:
|
|
for log in container.logs(stream=True, follow=True):
|
|
print(
|
|
f"[TGI Server Logs - {test_name}] {log.decode('utf-8')}",
|
|
end="",
|
|
file=sys.stderr,
|
|
flush=True,
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Error streaming container logs: {str(e)}")
|
|
|
|
|
|
class LauncherHandle:
|
|
def __init__(self, port: int):
|
|
self.client = AsyncClient(f"http://localhost:{port}", timeout=3600)
|
|
|
|
def _inner_health(self):
|
|
raise NotImplementedError
|
|
|
|
async def health(self, timeout: int = 60):
|
|
assert timeout > 0
|
|
start_time = time.time()
|
|
logger.info(f"Starting health check with timeout of {timeout}s")
|
|
|
|
for attempt in range(timeout):
|
|
if not self._inner_health():
|
|
logger.error("Launcher crashed during health check")
|
|
raise RuntimeError("Launcher crashed")
|
|
|
|
try:
|
|
await self.client.generate("test")
|
|
elapsed = time.time() - start_time
|
|
logger.info(f"Health check passed after {elapsed:.1f}s")
|
|
return
|
|
except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
|
|
if attempt == timeout - 1:
|
|
logger.error(f"Health check failed after {timeout}s: {str(e)}")
|
|
raise RuntimeError(f"Health check failed: {str(e)}")
|
|
if attempt % 10 == 0 and attempt != 0: # Only log every 10th attempt
|
|
logger.debug(
|
|
f"Connection attempt {attempt}/{timeout} failed: {str(e)}"
|
|
)
|
|
time.sleep(1)
|
|
except Exception as e:
|
|
logger.error(f"Unexpected error during health check: {str(e)}")
|
|
# Get full traceback for debugging
|
|
import traceback
|
|
|
|
logger.error(f"Full traceback:\n{traceback.format_exc()}")
|
|
raise
|
|
|
|
|
|
class ContainerLauncherHandle(LauncherHandle):
|
|
def __init__(self, docker_client, container_name, port: int):
|
|
super(ContainerLauncherHandle, self).__init__(port)
|
|
self.docker_client = docker_client
|
|
self.container_name = container_name
|
|
|
|
def _inner_health(self) -> bool:
|
|
try:
|
|
container = self.docker_client.containers.get(self.container_name)
|
|
status = container.status
|
|
if status not in ["running", "created"]:
|
|
logger.warning(f"Container status is {status}")
|
|
# Get container logs for debugging
|
|
logs = container.logs().decode("utf-8")
|
|
logger.debug(f"Container logs:\n{logs}")
|
|
return status in ["running", "created"]
|
|
except Exception as e:
|
|
logger.error(f"Error checking container health: {str(e)}")
|
|
return False
|
|
|
|
|
|
class ProcessLauncherHandle(LauncherHandle):
|
|
def __init__(self, process, port: int):
|
|
super(ProcessLauncherHandle, self).__init__(port)
|
|
self.process = process
|
|
|
|
def _inner_health(self) -> bool:
|
|
return self.process.poll() is None
|
|
|
|
|
|
@pytest.fixture(scope="module")
|
|
def data_volume():
|
|
tmpdir = TemporaryDirectory()
|
|
yield tmpdir.name
|
|
try:
|
|
# Cleanup the temporary directory using sudo as it contains root files created by the container
|
|
subprocess.run(shlex.split(f"sudo rm -rf {tmpdir.name}"), check=True)
|
|
except subprocess.CalledProcessError as e:
|
|
logger.error(f"Error cleaning up temporary directory: {str(e)}")
|
|
|
|
|
|
@pytest.fixture(scope="module")
|
|
def launcher(data_volume):
|
|
@contextlib.contextmanager
|
|
def docker_launcher(
|
|
model_id: str,
|
|
test_name: str,
|
|
):
|
|
logger.info(
|
|
f"Starting docker launcher for model {model_id} and test {test_name}"
|
|
)
|
|
|
|
# Get a random available port
|
|
def get_free_port():
|
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
s.bind(("", 0))
|
|
s.listen(1)
|
|
port = s.getsockname()[1]
|
|
return port
|
|
|
|
port = get_free_port()
|
|
logger.debug(f"Using port {port}")
|
|
|
|
client = docker.from_env()
|
|
|
|
container_name = f"tgi-gaudi-test-{test_name.replace('/', '-')}"
|
|
|
|
try:
|
|
container = client.containers.get(container_name)
|
|
logger.info(
|
|
f"Stopping existing container {container_name} for test {test_name}"
|
|
)
|
|
container.stop()
|
|
container.wait()
|
|
except NotFound:
|
|
pass
|
|
except Exception as e:
|
|
logger.error(f"Error handling existing container: {str(e)}")
|
|
|
|
model_name = next(
|
|
name for name, cfg in TEST_CONFIGS.items() if cfg["model_id"] == model_id
|
|
)
|
|
|
|
tgi_args = TEST_CONFIGS[model_name]["args"].copy()
|
|
|
|
env = BASE_ENV.copy()
|
|
|
|
# Add model_id to env
|
|
env["MODEL_ID"] = model_id
|
|
|
|
# Add env config that is definied in the fixture parameter
|
|
if "env_config" in TEST_CONFIGS[model_name]:
|
|
env.update(TEST_CONFIGS[model_name]["env_config"].copy())
|
|
|
|
volumes = [f"{DOCKER_VOLUME}:/data"]
|
|
logger.debug(f"Using volume {volumes}")
|
|
|
|
try:
|
|
logger.info(f"Creating container with name {container_name}")
|
|
|
|
# Log equivalent docker run command for debugging, this is not actually executed
|
|
container = client.containers.run(
|
|
DOCKER_IMAGE,
|
|
command=tgi_args,
|
|
name=container_name,
|
|
environment=env,
|
|
detach=True,
|
|
volumes=volumes,
|
|
ports={"80/tcp": port},
|
|
**HABANA_RUN_ARGS,
|
|
)
|
|
|
|
logger.info(f"Container {container_name} started successfully")
|
|
|
|
# Start log streaming in a background thread
|
|
log_thread = threading.Thread(
|
|
target=stream_container_logs,
|
|
args=(container, test_name),
|
|
daemon=True, # This ensures the thread will be killed when the main program exits
|
|
)
|
|
log_thread.start()
|
|
|
|
# Add a small delay to allow container to initialize
|
|
time.sleep(2)
|
|
|
|
# Check container status after creation
|
|
status = container.status
|
|
logger.debug(f"Initial container status: {status}")
|
|
if status not in ["running", "created"]:
|
|
logs = container.logs().decode("utf-8")
|
|
logger.error(f"Container failed to start properly. Logs:\n{logs}")
|
|
|
|
yield ContainerLauncherHandle(client, container.name, port)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error starting container: {str(e)}")
|
|
# Get full traceback for debugging
|
|
import traceback
|
|
|
|
logger.error(f"Full traceback:\n{traceback.format_exc()}")
|
|
raise
|
|
finally:
|
|
try:
|
|
container = client.containers.get(container_name)
|
|
logger.info(f"Stopping container {container_name}")
|
|
container.stop()
|
|
container.wait()
|
|
|
|
container_output = container.logs().decode("utf-8")
|
|
print(container_output, file=sys.stderr)
|
|
|
|
container.remove()
|
|
logger.info(f"Container {container_name} removed successfully")
|
|
except NotFound:
|
|
pass
|
|
except Exception as e:
|
|
logger.warning(f"Error cleaning up container: {str(e)}")
|
|
|
|
return docker_launcher
|
|
|
|
|
|
@pytest.fixture(scope="module")
|
|
def generate_load():
|
|
async def generate_load_inner(
|
|
client: AsyncClient, prompt: str, max_new_tokens: int, n: int
|
|
) -> List[Response]:
|
|
try:
|
|
futures = [
|
|
client.generate(
|
|
prompt,
|
|
max_new_tokens=max_new_tokens,
|
|
decoder_input_details=True,
|
|
)
|
|
for _ in range(n)
|
|
]
|
|
return await asyncio.gather(*futures)
|
|
except Exception as e:
|
|
logger.error(f"Error generating load: {str(e)}")
|
|
raise
|
|
|
|
return generate_load_inner
|