fix cache cleanup

This commit is contained in:
Felix Marty 2024-06-26 10:02:45 +00:00
parent 04298e5799
commit b44097a61b

View File

@ -39,9 +39,11 @@ def cleanup_cache(token: str, cache_dir: str):
# Retrieve the size per model for all models used in the CI.
size_per_model = {}
extension_per_model = {}
checkpoints_per_model = {}
for model_id, revision in REQUIRED_MODELS.items():
print(f"Crawling {model_id}...")
model_size = 0
checkpoints = {}
all_files = huggingface_hub.list_repo_files(
model_id,
repo_type="model",
@ -68,11 +70,13 @@ def cleanup_cache(token: str, cache_dir: str):
file_url, token=token
)
model_size += file_metadata.size * 1e-9 # in GB
checkpoints[filename] = file_metadata.size * 1e-9
size_per_model[model_id] = model_size
checkpoints_per_model[model_id] = checkpoints
total_required_size = sum(size_per_model.values())
print(f"Total required disk: {total_required_size:.2f} GB")
print(f"Total required disk for checkpoints: {total_required_size:.2f} GB")
cached_dir = huggingface_hub.scan_cache_dir(cache_dir)
@ -83,9 +87,20 @@ def cleanup_cache(token: str, cache_dir: str):
# Retrieve the SHAs and model ids of other non-necessary models in the cache.
for repo in cached_dir.repos:
if repo.repo_id in REQUIRED_MODELS:
cached_required_size_per_model[repo.repo_id] = (
repo.size_on_disk * 1e-9
) # in GB
cached_required_size_per_model[repo.repo_id] = 0
for checkpoint in checkpoints_per_model[repo.repo_id]:
filepath = huggingface_hub.try_to_load_from_cache(
repo.repo_id,
checkpoint,
cache_dir=cache_dir,
revision=REQUIRED_MODELS[repo.repo_id],
)
if isinstance(filepath, str):
cached_required_size_per_model[
repo.repo_id
] += checkpoints_per_model[repo.repo_id][checkpoint]
elif repo.repo_type == "model":
cache_size_per_model[repo.repo_id] = repo.size_on_disk * 1e-9 # in GB
@ -96,8 +111,13 @@ def cleanup_cache(token: str, cache_dir: str):
total_required_cached_size = sum(cached_required_size_per_model.values())
total_other_cached_size = sum(cache_size_per_model.values())
total_non_cached_required_size = total_required_size - total_required_cached_size
total_non_cached_required_size = total_required_size - total_required_cached_size
assert total_non_cached_required_size >= 0
print(
f"Total non-cached required models size: {total_non_cached_required_size:.2f} GB (to be downloaded)"
)
print(
f"Total HF cached models size: {total_other_cached_size + total_required_cached_size:.2f} GB"
)
@ -113,7 +133,7 @@ def cleanup_cache(token: str, cache_dir: str):
"Not enough space on device to execute the complete CI, please clean up the CI machine"
)
while free_memory < total_non_cached_required_size * 1.05:
while free_memory < 10 + total_non_cached_required_size * 1.05:
if len(cache_size_per_model) == 0:
raise ValueError("This should not happen.")