Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-26 12:32:10 +00:00
test(neuron): refactor to prepare batch export
This commit is contained in:
parent b0069e0485
commit 70e846d53b
@@ -118,10 +118,11 @@ def get_tgi_docker_image():
     return docker_image
 
 
-def export_model(config_name, model_config, neuron_model_name):
-    """Export a neuron model.
+def maybe_export_model(config_name, model_config):
+    """Export a neuron model for the specified test configuration.
 
-    The model is exported by a custom image built on the fly from the base TGI image.
+    If the neuron model has not already been compiled and pushed to the hub, it is
+    exported by a custom image built on the fly from the base TGI image.
     This makes sure the exported model and image are aligned and avoids introducing
     neuron specific imports in the test suite.
 
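Note: the docstring above describes the export mechanism, but the body that builds and runs the image is not part of this diff. The following is only a minimal docker-py sketch of the pattern it describes (build a throwaway image on top of the base TGI image, run the export inside it, then remove the image); the build context, image tag, command and environment variable are illustrative assumptions, not the project's actual code.

import docker

def sketch_export_with_throwaway_image(base_image, context_dir):
    # Hypothetical sketch only; names, Dockerfile and command are assumptions.
    client = docker.from_env()
    # Build an image derived from the base TGI image so that the compiled
    # artifacts match the exact runtime the tests will later use.
    image, _ = client.images.build(
        path=context_dir,                      # assumed build context containing a Dockerfile
        buildargs={"BASE_IMAGE": base_image},  # assumed build argument
        tag="neuron-export-sketch:dev",        # assumed tag
    )
    try:
        # Run the export inside the freshly built image (placeholder command).
        client.containers.run(
            image.id,
            command="python export_model.py",         # placeholder, not the real entry point
            environment={"HF_TOKEN": "<hub token>"},  # assumed variable
            remove=True,
        )
    finally:
        # Clean up the temporary image, mirroring the removal visible at the end of the diff.
        client.images.remove(image.id, force=True)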
@@ -130,9 +131,15 @@ def export_model(config_name, model_config, neuron_model_name):
             Used to identify test configurations
         model_config (`str`):
             The model configuration for export (includes the original model id)
-        neuron_model_name (`str`):
-            The name of the exported model on the hub
     """
+    neuron_model_name = get_neuron_model_name(config_name)
+    neuron_model_id = f"{TEST_ORGANIZATION}/{neuron_model_name}"
+    hub = huggingface_hub.HfApi()
+    if hub.repo_exists(neuron_model_id):
+        logger.info(
+            f"Skipping model export for config {config_name} as {neuron_model_id} already exists"
+        )
+        return neuron_model_id
     client = docker.from_env()
 
@@ -206,6 +213,7 @@ def export_model(config_name, model_config, neuron_model_name):
     except Exception as e:
         logger.error("Error while removing image %s, skipping", image.id)
         logger.exception(e)
+    return neuron_model_id
 
 
 @pytest.fixture(scope="session", params=MODEL_CONFIGURATIONS.keys())
@@ -232,14 +240,11 @@ def neuron_model_config(request):
     """
     config_name = request.param
     model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param])
-    neuron_model_name = get_neuron_model_name(config_name)
-    neuron_model_id = f"{TEST_ORGANIZATION}/{neuron_model_name}"
+    # Export the model first (only if needed)
+    neuron_model_id = maybe_export_model(config_name, model_config)
     with TemporaryDirectory() as neuron_model_path:
-        hub = huggingface_hub.HfApi()
-        if not hub.repo_exists(neuron_model_id):
-            # Export the model first
-            export_model(config_name, model_config, neuron_model_name)
         logger.info(f"Fetching {neuron_model_id} from the HuggingFace hub")
+        hub = huggingface_hub.HfApi()
         hub.snapshot_download(neuron_model_id, local_dir=neuron_model_path)
         # Add dynamic parameters to the model configuration
         model_config["neuron_model_path"] = neuron_model_path
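After this change, a test only needs to request the session-scoped neuron_model_config fixture: the export (if needed), the hub download and the path injection all happen inside the fixture. A hypothetical consumer could look like the sketch below; only the "neuron_model_path" key is shown being set in this diff, and the asserted config.json file is an assumption about what a downloaded snapshot contains.

import os

def test_neuron_model_is_available_locally(neuron_model_config):
    # The fixture has already called maybe_export_model() and downloaded the
    # snapshot from the hub into a temporary directory.
    neuron_model_path = neuron_model_config["neuron_model_path"]
    assert os.path.isdir(neuron_model_path)
    # Assumption: a model snapshot ships a config.json at its root.
    assert os.path.exists(os.path.join(neuron_model_path, "config.json"))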