From b44097a61beb48d4b64e04c9e4e8803bb8052134 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 26 Jun 2024 10:02:45 +0000 Subject: [PATCH] fix cache cleanup --- integration-tests/clean_cache_and_download.py | 32 +++++++++++++++---- 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/integration-tests/clean_cache_and_download.py b/integration-tests/clean_cache_and_download.py index c2a3960c..7d843b76 100644 --- a/integration-tests/clean_cache_and_download.py +++ b/integration-tests/clean_cache_and_download.py @@ -39,9 +39,11 @@ def cleanup_cache(token: str, cache_dir: str): # Retrieve the size per model for all models used in the CI. size_per_model = {} extension_per_model = {} + checkpoints_per_model = {} for model_id, revision in REQUIRED_MODELS.items(): print(f"Crawling {model_id}...") model_size = 0 + checkpoints = {} all_files = huggingface_hub.list_repo_files( model_id, repo_type="model", @@ -68,11 +70,13 @@ def cleanup_cache(token: str, cache_dir: str): file_url, token=token ) model_size += file_metadata.size * 1e-9 # in GB + checkpoints[filename] = file_metadata.size * 1e-9 size_per_model[model_id] = model_size + checkpoints_per_model[model_id] = checkpoints total_required_size = sum(size_per_model.values()) - print(f"Total required disk: {total_required_size:.2f} GB") + print(f"Total required disk for checkpoints: {total_required_size:.2f} GB") cached_dir = huggingface_hub.scan_cache_dir(cache_dir) @@ -83,9 +87,20 @@ def cleanup_cache(token: str, cache_dir: str): # Retrieve the SHAs and model ids of other non-necessary models in the cache. for repo in cached_dir.repos: if repo.repo_id in REQUIRED_MODELS: - cached_required_size_per_model[repo.repo_id] = ( - repo.size_on_disk * 1e-9 - ) # in GB + cached_required_size_per_model[repo.repo_id] = 0 + + for checkpoint in checkpoints_per_model[repo.repo_id]: + filepath = huggingface_hub.try_to_load_from_cache( + repo.repo_id, + checkpoint, + cache_dir=cache_dir, + revision=REQUIRED_MODELS[repo.repo_id], + ) + + if isinstance(filepath, str): + cached_required_size_per_model[ + repo.repo_id + ] += checkpoints_per_model[repo.repo_id][checkpoint] elif repo.repo_type == "model": cache_size_per_model[repo.repo_id] = repo.size_on_disk * 1e-9 # in GB @@ -96,8 +111,13 @@ def cleanup_cache(token: str, cache_dir: str): total_required_cached_size = sum(cached_required_size_per_model.values()) total_other_cached_size = sum(cache_size_per_model.values()) - total_non_cached_required_size = total_required_size - total_required_cached_size + total_non_cached_required_size = total_required_size - total_required_cached_size + assert total_non_cached_required_size >= 0 + + print( + f"Total non-cached required models size: {total_non_cached_required_size:.2f} GB (to be downloaded)" + ) print( f"Total HF cached models size: {total_other_cached_size + total_required_cached_size:.2f} GB" ) @@ -113,7 +133,7 @@ def cleanup_cache(token: str, cache_dir: str): "Not enough space on device to execute the complete CI, please clean up the CI machine" ) - while free_memory < total_non_cached_required_size * 1.05: + while free_memory < 10 + total_non_cached_required_size * 1.05: if len(cache_size_per_model) == 0: raise ValueError("This should not happen.")