# text-generation-inference/load_tests/benchmarks/engine.py
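"""Inference engine runners for the load-test benchmarks: launch TGI as a local
subprocess, or TGI and vLLM as Docker containers, stream their logs until the
engine is ready to serve, and tear everything down afterwards."""
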
import subprocess
import threading

import docker
from docker.models.containers import Container
from loguru import logger

from benchmarks.utils import kill


class InferenceEngineRunner:
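    """Base interface for engine runners: subclasses start the engine in run()
    and tear it down in stop()."""
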
    def __init__(self, model: str):
        self.model = model

    def run(self, parameters: list[tuple]):
        raise NotImplementedError("This method should be implemented by the subclass")

    def stop(self):
        raise NotImplementedError("This method should be implemented by the subclass")


class TGIRunner(InferenceEngineRunner):
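    """Runs text-generation-launcher as a local subprocess and streams its logs."""
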
    def __init__(self, model: str):
        super().__init__(model)
        self.process = None
        self.thread = None

    def run(self, parameters: list[tuple]):
        params = ""
        for p in parameters:
            params += f" --{p[0]} {p[1]}"
        # start a TGI subprocess with the given parameters
        args = f"text-generation-launcher --port 8080 --model-id {self.model} --huggingface-hub-cache /scratch{params}"
        logger.info(f"Running TGI with parameters: {args}")
        # merge stderr into stdout so a single reader sees all output and cannot deadlock on a full stderr pipe
        self.process = subprocess.Popen(args,
                                        stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True)
        # wait for TGI to listen on the port
        for line in iter(self.process.stdout.readline, b""):
            print(line.decode("utf-8"))
            if b"Connected" in line:
                break
            if b"Error" in line:
                raise Exception(f"Error starting TGI: {line}")
        if self.process.poll() is not None:
            raise Exception("Error starting TGI")

        # continue to stream the logs in a background thread
        def stream_logs():
            for line in iter(self.process.stdout.readline, b""):
                print(line.decode("utf-8"))

        self.thread = threading.Thread(target=stream_logs)
        self.thread.start()

    def stop(self):
        if self.process:
            logger.warning(f"Killing TGI with PID {self.process.pid}")
            kill(self.process.pid)
        if self.thread:
            self.thread.join()


class TGIDockerRunner(InferenceEngineRunner):
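    """Runs TGI in a Docker container, waiting until the router logs "Connected"."""
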
def __init__(self,
model: str,
image: str = "ghcr.io/huggingface/text-generation-inference:latest",
volumes=None):
super().__init__(model)
if volumes is None:
volumes = []
self.container = None
self.image = image
self.volumes = volumes

    def run(self, parameters: list[tuple]):
        params = f"--model-id {self.model} --port 8080"
        for p in parameters:
            params += f" --{p[0]} {p[1]}"
        logger.info(f"Running TGI with parameters: {params}")
        # map (host_path, container_path) volume tuples to the docker-py format
        volumes = {}
        for v in self.volumes:
            volumes[v[0]] = {"bind": v[1], "mode": "rw"}
        self.container = run_docker(self.image, params,
                                    "Connected",
                                    "Error",
                                    volumes=volumes)

    def stop(self):
        if self.container:
            self.container.stop()


class VLLMDockerRunner(InferenceEngineRunner):
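    """Runs vLLM's OpenAI-compatible server in a Docker container, waiting for "Uvicorn running"."""
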
def __init__(self,
model: str,
image: str = "vllm/vllm-openai:latest",
volumes=None):
super().__init__(model)
if volumes is None:
volumes = []
self.container = None
self.image = image
self.volumes = volumes

    def run(self, parameters: list[tuple]):
        # copy instead of appending in place so the caller's list is not mutated across runs
        parameters = parameters + [("max-num-seqs", "256")]
        params = f"--model {self.model} --tensor-parallel-size {get_num_gpus()} --port 8080"
        for p in parameters:
            params += f" --{p[0]} {p[1]}"
        logger.info(f"Running VLLM with parameters: {params}")
        volumes = {}
        for v in self.volumes:
            volumes[v[0]] = {"bind": v[1], "mode": "rw"}
        self.container = run_docker(self.image, params, "Uvicorn running",
                                    "Error ",
                                    volumes=volumes)

    def stop(self):
        if self.container:
            self.container.stop()


def run_docker(image: str, args: str, success_sentinel: str,
               error_sentinel: str, volumes=None) -> Container:
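    """Start a detached container with all local GPUs attached and stream its
    logs, returning once success_sentinel appears; stop the container and
    raise if error_sentinel appears first."""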
    if volumes is None:
        volumes = {}
    client = docker.from_env()
    # request one device id per GPU visible on the host
    devices = [f"{i}" for i in range(get_num_gpus())]
container = client.containers.run(image, args,
detach=True,
device_requests=[
docker.types.DeviceRequest(device_ids=devices, capabilities=[['gpu']])
],
volumes=volumes,
shm_size="1g",
ports={"8080/tcp": 8080})
for line in container.logs(stream=True):
print(line.decode("utf-8"), end="")
if success_sentinel.encode("utf-8") in line:
break
if error_sentinel.encode("utf-8") in line:
container.stop()
raise Exception(f"Error starting container: {line}")
return container


def get_num_gpus() -> int:
    # nvidia-smi -L prints one line per GPU
    return len(subprocess.run(["nvidia-smi", "-L"], capture_output=True).stdout.splitlines())
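

# A minimal usage sketch (an assumption, not part of the original module): the
# model id, volume mapping, and launcher flag below are illustrative; a load-test
# client would be pointed at http://localhost:8080 while the container runs.
if __name__ == "__main__":
    runner = TGIDockerRunner("gpt2", volumes=[("/scratch", "/data")])
    try:
        runner.run([("max-concurrent-requests", 128)])
        # ... drive load against http://localhost:8080 here ...
    finally:
        runner.stop()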