fix(neuron): export models from container in test fixtures

The neuron tests require models to have been previously exported and
cached on the hub. This is done automatically by the neuron.model
fixture the first time the tests are ran for a specific version.
This fixture used to export the models using optimum-neuron directly,
but this package is not necessarily present on the system.
Instead, it is now done through the neuron TGI itself, since it
contains all the tools required to export the models.
Note that since the CI runs docker in docker (dind) it does not seem
possible to share a volume between the CI container and the container
used to export the model.
For that reason, a specific image with a modified entrypoint is built
on-the-fly when a model export is required.
This commit is contained in:
David Corvoysier 2025-02-18 17:47:54 +00:00
parent bb51c5138c
commit ae37890eef

View File

@ -1,17 +1,20 @@
import copy import copy
import logging import logging
import subprocess
import sys import sys
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
import huggingface_hub import huggingface_hub
import pytest import pytest
from transformers import AutoTokenizer import docker
import os
import tempfile
from optimum.neuron import NeuronModelForCausalLM from docker.errors import NotFound
from optimum.neuron.utils import synchronize_hub_cache
from optimum.neuron.version import __sdk_version__ as sdk_version
from optimum.neuron.version import __version__ as version TEST_ORGANIZATION = "optimum-internal-testing"
TEST_CACHE_REPO_ID = f"{TEST_ORGANIZATION}/neuron-testing-cache"
HF_TOKEN = huggingface_hub.get_token()
logging.basicConfig( logging.basicConfig(
@ -21,7 +24,6 @@ logging.basicConfig(
) )
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
OPTIMUM_CACHE_REPO_ID = "optimum-internal-testing/neuron-testing-cache"
# All model configurations below will be added to the neuron_model_config fixture # All model configurations below will be added to the neuron_model_config fixture
MODEL_CONFIGURATIONS = { MODEL_CONFIGURATIONS = {
@ -48,21 +50,106 @@ MODEL_CONFIGURATIONS = {
} }
def get_hub_neuron_model_id(config_name: str): def get_neuron_model_name(config_name: str):
return f"optimum-internal-testing/neuron-testing-{version}-{sdk_version}-{config_name}" version = get_tgi_docker_image().split(":")[-1]
return f"neuron-tgi-testing-{config_name}-{version}"
def export_model(model_id, export_kwargs, neuron_model_path): def get_tgi_docker_image():
export_command = ["optimum-cli", "export", "neuron", "-m", model_id, "--task", "text-generation"] docker_image = os.getenv("DOCKER_IMAGE", None)
if docker_image is None:
client = docker.from_env()
images = client.images.list(filters={"reference": "text-generation-inference"})
if not images:
raise ValueError("No text-generation-inference image found on this host to run tests.")
docker_image = images[0].tags[0]
return docker_image
def export_model(config_name, model_config, neuron_model_name):
"""Export a neuron model.
The model is exported by a custom image built on the fly from the base TGI image.
This makes sure the exported model and image are aligned and avoids introducing
neuron specific imports in the test suite.
Args:
config_name (`str`):
Used to identify test configurations
model_config (`str`):
The model configuration for export (includes the original model id)
neuron_model_name (`str`):
The name of the exported model on the hub
"""
client = docker.from_env()
env = {"LOG_LEVEL": "info", "CUSTOM_CACHE_REPO": TEST_CACHE_REPO_ID}
if HF_TOKEN is not None:
env["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
env["HF_TOKEN"] = HF_TOKEN
# Create a sub-image to export the model to workaround docker dind issues preventing
# to share a volume from the container running tests
model_id = model_config["model_id"]
export_kwargs = model_config["export_kwargs"]
base_image = get_tgi_docker_image()
export_image = f"neuron-tgi-tests-{config_name}-export-img"
logger.info(f"Building temporary image {export_image} from {base_image}")
with tempfile.TemporaryDirectory() as context_dir:
# Create entrypoint
model_path = "/data/neuron_model"
export_command = f"optimum-cli export neuron -m {model_id} --task text-generation"
for kwarg, value in export_kwargs.items(): for kwarg, value in export_kwargs.items():
export_command.append(f"--{kwarg}") export_command += f" --{kwarg} {str(value)}"
export_command.append(str(value)) export_command += f" {model_path}"
export_command.append(neuron_model_path) entrypoint_content = f"""#!/bin/sh
logger.info(f"Exporting {model_id} with {export_kwargs}") {export_command}
huggingface-cli repo create --organization {TEST_ORGANIZATION} {neuron_model_name}
huggingface-cli upload {TEST_ORGANIZATION}/{neuron_model_name} {model_path} --exclude *.bin *.safetensors
optimum-cli neuron cache synchronize --repo_id {TEST_CACHE_REPO_ID}
"""
with open(os.path.join(context_dir, "entrypoint.sh"), "wb") as f:
f.write(entrypoint_content.encode("utf-8"))
f.flush()
# Create Dockerfile
docker_content = f"""
FROM {base_image}
COPY entrypoint.sh /export-entrypoint.sh
RUN chmod +x /export-entrypoint.sh
ENTRYPOINT ["/export-entrypoint.sh"]
"""
with open(os.path.join(context_dir, "Dockerfile"), "wb") as f:
f.write(docker_content.encode("utf-8"))
f.flush()
image, logs = client.images.build(path=context_dir, dockerfile=f.name, tag=export_image)
logger.info("Successfully built image %s", image.id)
logger.debug("Build logs %s", logs)
try: try:
subprocess.run(export_command, check=True) container = client.containers.run(
except subprocess.CalledProcessError as e: export_image,
raise ValueError(f"Failed to export model: {e}") environment=env,
auto_remove=True,
detach=False,
devices=["/dev/neuron0"],
shm_size="1G",
)
logger.info(f"Successfully exported model for config {config_name}")
container.logs()
except Exception as e:
logger.exception(f"An exception occurred while running container: {e}.")
pass
finally:
# Cleanup the export image
logger.info("Cleaning image %s", image.id)
try:
image.remove(force=True)
except NotFound:
pass
except Exception as e:
logger.error("Error while removing image %s, skipping", image.id)
logger.exception(e)
@pytest.fixture(scope="session", params=MODEL_CONFIGURATIONS.keys()) @pytest.fixture(scope="session", params=MODEL_CONFIGURATIONS.keys())
@ -70,7 +157,7 @@ def neuron_model_config(request):
"""Expose a pre-trained neuron model """Expose a pre-trained neuron model
The fixture first makes sure the following model artifacts are present on the hub: The fixture first makes sure the following model artifacts are present on the hub:
- exported neuron model under optimum-internal-testing/neuron-testing-<version>-<name>, - exported neuron model under optimum-internal-testing/neuron-testing-<name>-<version>,
- cached artifacts under optimum-internal-testing/neuron-testing-cache. - cached artifacts under optimum-internal-testing/neuron-testing-cache.
If not, it will export the model and push it to the hub. If not, it will export the model and push it to the hub.
@ -89,28 +176,15 @@ def neuron_model_config(request):
""" """
config_name = request.param config_name = request.param
model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param]) model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param])
model_id = model_config["model_id"] neuron_model_name = get_neuron_model_name(config_name)
export_kwargs = model_config["export_kwargs"] neuron_model_id = f"{TEST_ORGANIZATION}/{neuron_model_name}"
neuron_model_id = get_hub_neuron_model_id(config_name)
with TemporaryDirectory() as neuron_model_path: with TemporaryDirectory() as neuron_model_path:
hub = huggingface_hub.HfApi() hub = huggingface_hub.HfApi()
if hub.repo_exists(neuron_model_id): if not hub.repo_exists(neuron_model_id):
# Export the model first
export_model(config_name, model_config, neuron_model_name)
logger.info(f"Fetching {neuron_model_id} from the HuggingFace hub") logger.info(f"Fetching {neuron_model_id} from the HuggingFace hub")
hub.snapshot_download(neuron_model_id, local_dir=neuron_model_path) hub.snapshot_download(neuron_model_id, local_dir=neuron_model_path)
else:
export_model(model_id, export_kwargs, neuron_model_path)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.save_pretrained(neuron_model_path)
del tokenizer
# Create the test model on the hub
hub.create_repo(neuron_model_id, private=True)
hub.upload_folder(
folder_path=neuron_model_path,
repo_id=neuron_model_id,
ignore_patterns=[NeuronModelForCausalLM.CHECKPOINT_DIR + "/*"],
)
# Make sure it is cached
synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID)
# Add dynamic parameters to the model configuration # Add dynamic parameters to the model configuration
model_config["neuron_model_path"] = neuron_model_path model_config["neuron_model_path"] = neuron_model_path
model_config["neuron_model_id"] = neuron_model_id model_config["neuron_model_id"] = neuron_model_id