From 70e846d53b596b2dfa213d5c3fce8569d972c7a5 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Mon, 24 Feb 2025 17:36:26 +0000 Subject: [PATCH] test(neuron): refactor to prepare batch export --- integration-tests/fixtures/neuron/model.py | 27 +++++++++++++--------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/integration-tests/fixtures/neuron/model.py b/integration-tests/fixtures/neuron/model.py index 3345e2ea..2d58351c 100644 --- a/integration-tests/fixtures/neuron/model.py +++ b/integration-tests/fixtures/neuron/model.py @@ -118,10 +118,11 @@ def get_tgi_docker_image(): return docker_image -def export_model(config_name, model_config, neuron_model_name): - """Export a neuron model. +def maybe_export_model(config_name, model_config): + """Export a neuron model for the specified test configuration. - The model is exported by a custom image built on the fly from the base TGI image. + If the neuron model has not already been compiled and pushed to the hub, it is + exported by a custom image built on the fly from the base TGI image. This makes sure the exported model and image are aligned and avoids introducing neuron specific imports in the test suite. @@ -130,9 +131,15 @@ def export_model(config_name, model_config, neuron_model_name): Used to identify test configurations model_config (`str`): The model configuration for export (includes the original model id) - neuron_model_name (`str`): - The name of the exported model on the hub """ + neuron_model_name = get_neuron_model_name(config_name) + neuron_model_id = f"{TEST_ORGANIZATION}/{neuron_model_name}" + hub = huggingface_hub.HfApi() + if hub.repo_exists(neuron_model_id): + logger.info( + f"Skipping model export for config {config_name} as {neuron_model_id} already exists" + ) + return neuron_model_id client = docker.from_env() @@ -206,6 +213,7 @@ def export_model(config_name, model_config, neuron_model_name): except Exception as e: logger.error("Error while removing image %s, skipping", image.id) logger.exception(e) + return neuron_model_id @pytest.fixture(scope="session", params=MODEL_CONFIGURATIONS.keys()) @@ -232,14 +240,11 @@ def neuron_model_config(request): """ config_name = request.param model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param]) - neuron_model_name = get_neuron_model_name(config_name) - neuron_model_id = f"{TEST_ORGANIZATION}/{neuron_model_name}" + # Export the model first (only if needed) + neuron_model_id = maybe_export_model(config_name, model_config) with TemporaryDirectory() as neuron_model_path: - hub = huggingface_hub.HfApi() - if not hub.repo_exists(neuron_model_id): - # Export the model first - export_model(config_name, model_config, neuron_model_name) logger.info(f"Fetching {neuron_model_id} from the HuggingFace hub") + hub = huggingface_hub.HfApi() hub.snapshot_download(neuron_model_id, local_dir=neuron_model_path) # Add dynamic parameters to the model configuration model_config["neuron_model_path"] = neuron_model_path