test(neuron): refactor to prepare batch export

This commit is contained in:
David Corvoysier 2025-02-24 17:36:26 +00:00
parent b0069e0485
commit 70e846d53b

View File

@ -118,10 +118,11 @@ def get_tgi_docker_image():
return docker_image return docker_image
def export_model(config_name, model_config, neuron_model_name): def maybe_export_model(config_name, model_config):
"""Export a neuron model. """Export a neuron model for the specified test configuration.
The model is exported by a custom image built on the fly from the base TGI image. If the neuron model has not already been compiled and pushed to the hub, it is
exported by a custom image built on the fly from the base TGI image.
This makes sure the exported model and image are aligned and avoids introducing This makes sure the exported model and image are aligned and avoids introducing
neuron specific imports in the test suite. neuron specific imports in the test suite.
@ -130,9 +131,15 @@ def export_model(config_name, model_config, neuron_model_name):
Used to identify test configurations Used to identify test configurations
model_config (`str`): model_config (`str`):
The model configuration for export (includes the original model id) The model configuration for export (includes the original model id)
neuron_model_name (`str`):
The name of the exported model on the hub
""" """
neuron_model_name = get_neuron_model_name(config_name)
neuron_model_id = f"{TEST_ORGANIZATION}/{neuron_model_name}"
hub = huggingface_hub.HfApi()
if hub.repo_exists(neuron_model_id):
logger.info(
f"Skipping model export for config {config_name} as {neuron_model_id} already exists"
)
return neuron_model_id
client = docker.from_env() client = docker.from_env()
@ -206,6 +213,7 @@ def export_model(config_name, model_config, neuron_model_name):
except Exception as e: except Exception as e:
logger.error("Error while removing image %s, skipping", image.id) logger.error("Error while removing image %s, skipping", image.id)
logger.exception(e) logger.exception(e)
return neuron_model_id
@pytest.fixture(scope="session", params=MODEL_CONFIGURATIONS.keys()) @pytest.fixture(scope="session", params=MODEL_CONFIGURATIONS.keys())
@ -232,14 +240,11 @@ def neuron_model_config(request):
""" """
config_name = request.param config_name = request.param
model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param]) model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param])
neuron_model_name = get_neuron_model_name(config_name) # Export the model first (only if needed)
neuron_model_id = f"{TEST_ORGANIZATION}/{neuron_model_name}" neuron_model_id = maybe_export_model(config_name, model_config)
with TemporaryDirectory() as neuron_model_path: with TemporaryDirectory() as neuron_model_path:
hub = huggingface_hub.HfApi()
if not hub.repo_exists(neuron_model_id):
# Export the model first
export_model(config_name, model_config, neuron_model_name)
logger.info(f"Fetching {neuron_model_id} from the HuggingFace hub") logger.info(f"Fetching {neuron_model_id} from the HuggingFace hub")
hub = huggingface_hub.HfApi()
hub.snapshot_download(neuron_model_id, local_dir=neuron_model_path) hub.snapshot_download(neuron_model_id, local_dir=neuron_model_path)
# Add dynamic parameters to the model configuration # Add dynamic parameters to the model configuration
model_config["neuron_model_path"] = neuron_model_path model_config["neuron_model_path"] = neuron_model_path