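"""Fixtures for the neuron backend integration tests.

This module makes sure the neuron test models are exported and cached on the Hugging Face
hub, then exposes them to the tests through the `neuron_model_config` and
`neuron_model_path` fixtures defined below.
"""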
import copy
import hashlib
import logging
import os
import sys
import tempfile
from tempfile import TemporaryDirectory

import docker
import huggingface_hub
import pytest
from docker.errors import NotFound


TEST_ORGANIZATION = "optimum-internal-testing"
TEST_CACHE_REPO_ID = f"{TEST_ORGANIZATION}/neuron-testing-cache"
HF_TOKEN = huggingface_hub.get_token()


logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s [%(filename)s.%(funcName)s:%(lineno)d] %(message)s",
    stream=sys.stdout,
)
logger = logging.getLogger(__file__)


# All model configurations below will be added to the neuron_model_config fixture
MODEL_CONFIGURATIONS = {
    "gpt2": {
        "model_id": "gpt2",
        "export_kwargs": {"batch_size": 4, "sequence_length": 1024, "num_cores": 2, "auto_cast_type": "fp16"},
    },
    "llama": {
        "model_id": "unsloth/Llama-3.2-1B-Instruct",
        "export_kwargs": {"batch_size": 4, "sequence_length": 2048, "num_cores": 2, "auto_cast_type": "fp16"},
    },
    "mistral": {
        "model_id": "optimum/mistral-1.1b-testing",
        "export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "bf16"},
    },
    "qwen2": {
        "model_id": "Qwen/Qwen2.5-0.5B",
        "export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "fp16"},
    },
    "granite": {
        "model_id": "ibm-granite/granite-3.1-2b-instruct",
        "export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "bf16"},
    },
}

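# Note: each `export_kwargs` key above is forwarded verbatim as an `optimum-cli export neuron`
# flag by `export_model` below (e.g. "batch_size" becomes `--batch_size 4`).

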
def get_neuron_backend_hash():
    import subprocess

    res = subprocess.run(["git", "rev-parse", "--show-toplevel"],
                         capture_output=True,
                         text=True)
    root_dir = res.stdout.split('\n')[0]

    def get_sha(path):
        res = subprocess.run(["git", "ls-tree", "HEAD", f"{root_dir}/{path}"],
                             capture_output=True,
                             text=True)
        # Output of the command is in the form '040000 tree|blob <SHA>\t<path>\n'
        sha = res.stdout.split('\t')[0].split(' ')[-1]
        return sha.encode()

    # We hash both the neuron backends directory and Dockerfile and create a smaller hash out of that
    m = hashlib.sha256()
    m.update(get_sha('backends/neuron'))
    m.update(get_sha('Dockerfile.neuron'))
    return m.hexdigest()[:10]

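# Illustration of the function above (all values hypothetical): if `git ls-tree HEAD` reports
#   040000 tree 9fceb02...    backends/neuron
#   100644 blob 1d2a3b4...    Dockerfile.neuron
# the two tree/blob SHAs are fed into a sha256 digest and the first 10 hex characters
# (e.g. "a1b2c3d4e5") are used as the backend hash.

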
def get_neuron_model_name(config_name: str):
    return f"neuron-tgi-testing-{config_name}-{get_neuron_backend_hash()}"

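# For example (hash suffix hypothetical), the "llama" configuration maps to a hub model named
# "neuron-tgi-testing-llama-a1b2c3d4e5": the name only changes when the neuron backend sources
# or Dockerfile.neuron change, so exported models can be reused across test runs.

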
def get_tgi_docker_image():
    docker_image = os.getenv("DOCKER_IMAGE", None)
    if docker_image is None:
        client = docker.from_env()
        images = client.images.list(filters={"reference": "text-generation-inference"})
        if not images:
            raise ValueError("No text-generation-inference image found on this host to run tests.")
        docker_image = images[0].tags[0]
    return docker_image

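# Minimal usage sketch (the image tag and test invocation are assumptions, adjust to your setup):
#   DOCKER_IMAGE=text-generation-inference:my-neuron-build pytest integration-tests --neuron

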
def export_model(config_name, model_config, neuron_model_name):
    """Export a neuron model.

    The model is exported by a custom image built on the fly from the base TGI image.
    This makes sure the exported model and image are aligned and avoids introducing
    neuron specific imports in the test suite.

    Args:
        config_name (`str`):
            Used to identify test configurations
        model_config (`dict`):
            The model configuration for export (includes the original model id)
        neuron_model_name (`str`):
            The name of the exported model on the hub
    """

    client = docker.from_env()

    env = {"LOG_LEVEL": "info", "CUSTOM_CACHE_REPO": TEST_CACHE_REPO_ID}
    if HF_TOKEN is not None:
        env["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
        env["HF_TOKEN"] = HF_TOKEN

    # Create a sub-image to export the model, to work around docker-in-docker issues that
    # prevent sharing a volume from the container running the tests
    model_id = model_config["model_id"]
    export_kwargs = model_config["export_kwargs"]
    base_image = get_tgi_docker_image()
    export_image = f"neuron-tgi-tests-{config_name}-export-img"
    logger.info(f"Building temporary image {export_image} from {base_image}")
    with tempfile.TemporaryDirectory() as context_dir:
        # Create entrypoint
        model_path = "/data/neuron_model"
        export_command = f"optimum-cli export neuron -m {model_id} --task text-generation"
        for kwarg, value in export_kwargs.items():
            export_command += f" --{kwarg} {str(value)}"
        export_command += f" {model_path}"
        entrypoint_content = f"""#!/bin/sh
{export_command}
huggingface-cli repo create --organization {TEST_ORGANIZATION} {neuron_model_name}
huggingface-cli upload {TEST_ORGANIZATION}/{neuron_model_name} {model_path} --exclude *.bin *.safetensors
optimum-cli neuron cache synchronize --repo_id {TEST_CACHE_REPO_ID}
"""
        with open(os.path.join(context_dir, "entrypoint.sh"), "wb") as f:
            f.write(entrypoint_content.encode("utf-8"))
            f.flush()
        # Create Dockerfile
        docker_content = f"""
FROM {base_image}
COPY entrypoint.sh /export-entrypoint.sh
RUN chmod +x /export-entrypoint.sh
ENTRYPOINT ["/export-entrypoint.sh"]
"""
        with open(os.path.join(context_dir, "Dockerfile"), "wb") as f:
            f.write(docker_content.encode("utf-8"))
            f.flush()
        image, logs = client.images.build(path=context_dir, dockerfile=f.name, tag=export_image)
        logger.info("Successfully built image %s", image.id)
        logger.debug("Build logs %s", logs)

    try:
        # With detach=False, `containers.run` returns the container output once it has
        # completed, rather than a Container object.
        container_output = client.containers.run(
            export_image,
            environment=env,
            auto_remove=True,
            detach=False,
            devices=["/dev/neuron0"],
            shm_size="1G",
        )
        logger.info(f"Successfully exported model for config {config_name}")
        logger.debug(container_output)
    except Exception as e:
        logger.exception(f"An exception occurred while running container: {e}.")
    finally:
        # Cleanup the export image
        logger.info("Cleaning image %s", image.id)
        try:
            image.remove(force=True)
        except NotFound:
            pass
        except Exception as e:
            logger.error("Error while removing image %s, skipping", image.id)
            logger.exception(e)

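# For reference, the entrypoint generated above for the "gpt2" configuration would look roughly
# like this (reconstructed from the code, not captured from an actual build):
#   #!/bin/sh
#   optimum-cli export neuron -m gpt2 --task text-generation --batch_size 4 --sequence_length 1024 --num_cores 2 --auto_cast_type fp16 /data/neuron_model
#   huggingface-cli repo create --organization optimum-internal-testing <neuron_model_name>
#   huggingface-cli upload optimum-internal-testing/<neuron_model_name> /data/neuron_model --exclude *.bin *.safetensors
#   optimum-cli neuron cache synchronize --repo_id optimum-internal-testing/neuron-testing-cache

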
@pytest.fixture(scope="session", params=MODEL_CONFIGURATIONS.keys())
|
|
def neuron_model_config(request):
|
|
"""Expose a pre-trained neuron model
|
|
|
|
The fixture first makes sure the following model artifacts are present on the hub:
|
|
- exported neuron model under optimum-internal-testing/neuron-testing-<name>-<version>,
|
|
- cached artifacts under optimum-internal-testing/neuron-testing-cache.
|
|
If not, it will export the model and push it to the hub.
|
|
|
|
It then fetches the model locally and return a dictionary containing:
|
|
- a configuration name,
|
|
- the original model id,
|
|
- the export parameters,
|
|
- the neuron model id,
|
|
- the neuron model local path.
|
|
|
|
For each exposed model, the local directory is maintained for the duration of the
|
|
test session and cleaned up afterwards.
|
|
The hub model artifacts are never cleaned up and persist accross sessions.
|
|
They must be cleaned up manually when the optimum-neuron version changes.
|
|
|
|
"""
|
|
config_name = request.param
|
|
model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param])
|
|
neuron_model_name = get_neuron_model_name(config_name)
|
|
neuron_model_id = f"{TEST_ORGANIZATION}/{neuron_model_name}"
|
|
with TemporaryDirectory() as neuron_model_path:
|
|
hub = huggingface_hub.HfApi()
|
|
if not hub.repo_exists(neuron_model_id):
|
|
# Export the model first
|
|
export_model(config_name, model_config, neuron_model_name)
|
|
logger.info(f"Fetching {neuron_model_id} from the HuggingFace hub")
|
|
hub.snapshot_download(neuron_model_id, local_dir=neuron_model_path)
|
|
# Add dynamic parameters to the model configuration
|
|
model_config["neuron_model_path"] = neuron_model_path
|
|
model_config["neuron_model_id"] = neuron_model_id
|
|
# Also add model configuration name to allow tests to adapt their expectations
|
|
model_config["name"] = config_name
|
|
# Yield instead of returning to keep a reference to the temporary directory.
|
|
# It will go out of scope and be released only once all tests needing the fixture
|
|
# have been completed.
|
|
logger.info(f"{config_name} ready for testing ...")
|
|
yield model_config
|
|
logger.info(f"Done with {config_name}")
|
|
|
|
|
|
@pytest.fixture(scope="module")
|
|
def neuron_model_path(neuron_model_config):
|
|
yield neuron_model_config["neuron_model_path"]
|
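

# Illustrative only: a hypothetical test consuming the fixtures above (not part of this module):
#
#   def test_neuron_model_is_cached_locally(neuron_model_config, neuron_model_path):
#       assert neuron_model_config["neuron_model_id"].startswith(f"{TEST_ORGANIZATION}/")
#       assert os.path.isdir(neuron_model_path)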