fix cache cleanup

Felix Marty 2024-06-26 10:02:45 +00:00
parent 04298e5799
commit b44097a61b


@@ -39,9 +39,11 @@ def cleanup_cache(token: str, cache_dir: str):
     # Retrieve the size per model for all models used in the CI.
     size_per_model = {}
     extension_per_model = {}
+    checkpoints_per_model = {}
     for model_id, revision in REQUIRED_MODELS.items():
         print(f"Crawling {model_id}...")
         model_size = 0
+        checkpoints = {}
         all_files = huggingface_hub.list_repo_files(
             model_id,
             repo_type="model",
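The two new dictionaries track sizes at file granularity rather than per repository. For context, the crawl above relies on huggingface_hub.list_repo_files to enumerate every file of each pinned model. A minimal standalone sketch (the REQUIRED_MODELS mapping shown here is a hypothetical stand-in for the one defined earlier in the script):

import huggingface_hub

# Hypothetical stand-in for the REQUIRED_MODELS mapping defined elsewhere in the script.
REQUIRED_MODELS = {"gpt2": "main"}

for model_id, revision in REQUIRED_MODELS.items():
    # List every file tracked in the model repo at the pinned revision.
    all_files = huggingface_hub.list_repo_files(
        model_id, repo_type="model", revision=revision
    )
    print(f"{model_id}: {len(all_files)} files")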
@@ -68,11 +70,13 @@ def cleanup_cache(token: str, cache_dir: str):
                 file_url, token=token
             )
             model_size += file_metadata.size * 1e-9  # in GB
+            checkpoints[filename] = file_metadata.size * 1e-9
 
         size_per_model[model_id] = model_size
+        checkpoints_per_model[model_id] = checkpoints
 
     total_required_size = sum(size_per_model.values())
-    print(f"Total required disk: {total_required_size:.2f} GB")
+    print(f"Total required disk for checkpoints: {total_required_size:.2f} GB")
 
     cached_dir = huggingface_hub.scan_cache_dir(cache_dir)
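Each checkpoints[filename] entry stores the file size in GB, taken from file_metadata.size. That metadata presumably comes from huggingface_hub.get_hf_file_metadata on the resolved file URL; a small sketch of that lookup, assuming the URL is built with huggingface_hub.hf_hub_url (the actual URL construction sits outside this diff):

import huggingface_hub

# Hypothetical example values; the real script iterates over REQUIRED_MODELS and all_files.
model_id, revision, filename = "gpt2", "main", "model.safetensors"

# Resolve the file URL, then read its size from the HTTP metadata (HEAD request, no download).
file_url = huggingface_hub.hf_hub_url(model_id, filename, revision=revision)
file_metadata = huggingface_hub.get_hf_file_metadata(file_url)
print(f"{filename}: {file_metadata.size * 1e-9:.2f} GB")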
@@ -83,9 +87,20 @@ def cleanup_cache(token: str, cache_dir: str):
     # Retrieve the SHAs and model ids of other non-necessary models in the cache.
     for repo in cached_dir.repos:
         if repo.repo_id in REQUIRED_MODELS:
-            cached_required_size_per_model[repo.repo_id] = (
-                repo.size_on_disk * 1e-9
-            )  # in GB
+            cached_required_size_per_model[repo.repo_id] = 0
+
+            for checkpoint in checkpoints_per_model[repo.repo_id]:
+                filepath = huggingface_hub.try_to_load_from_cache(
+                    repo.repo_id,
+                    checkpoint,
+                    cache_dir=cache_dir,
+                    revision=REQUIRED_MODELS[repo.repo_id],
+                )
+
+                if isinstance(filepath, str):
+                    cached_required_size_per_model[
+                        repo.repo_id
+                    ] += checkpoints_per_model[repo.repo_id][checkpoint]
         elif repo.repo_type == "model":
             cache_size_per_model[repo.repo_id] = repo.size_on_disk * 1e-9  # in GB
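This is the core of the fix: instead of charging the whole repo.size_on_disk to a required model, only the checkpoint files actually present in the cache are counted. The isinstance(filepath, str) test matters because try_to_load_from_cache has three possible outcomes: the local path (a str) when the file is cached, None when it was never downloaded, and a sentinel object when a previous lookup cached the fact that the file does not exist on the Hub. A small sketch of that behaviour (repo and file names are illustrative):

from huggingface_hub import try_to_load_from_cache, _CACHED_NO_EXIST

# Hypothetical lookup: is this exact file already present in the local cache?
filepath = try_to_load_from_cache("gpt2", "model.safetensors", revision="main")

if isinstance(filepath, str):
    print("cached at", filepath)      # present on disk, so its size counts as cached
elif filepath is _CACHED_NO_EXIST:
    print("non-existence is cached")  # a previous 404 was recorded
else:
    print("not cached yet")           # file would still need to be downloaded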
@@ -96,8 +111,13 @@ def cleanup_cache(token: str, cache_dir: str):
     total_required_cached_size = sum(cached_required_size_per_model.values())
     total_other_cached_size = sum(cache_size_per_model.values())
 
     total_non_cached_required_size = total_required_size - total_required_cached_size
+    assert total_non_cached_required_size >= 0
+
+    print(
+        f"Total non-cached required models size: {total_non_cached_required_size:.2f} GB (to be downloaded)"
+    )
     print(
         f"Total HF cached models size: {total_other_cached_size + total_required_cached_size:.2f} GB"
     )
@@ -113,7 +133,7 @@ def cleanup_cache(token: str, cache_dir: str):
             "Not enough space on device to execute the complete CI, please clean up the CI machine"
         )
 
-    while free_memory < total_non_cached_required_size * 1.05:
+    while free_memory < 10 + total_non_cached_required_size * 1.05:
         if len(cache_size_per_model) == 0:
             raise ValueError("This should not happen.")
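The loop condition now requires roughly 10 GB of headroom on top of a 5% margin over the yet-to-be-downloaded size before it stops evicting. The loop body is not part of this diff; a minimal sketch of how one eviction step can be written with the huggingface_hub cache API, assuming a pick-the-largest-unneeded-repo policy and a shutil.disk_usage based free-space check (both are assumptions, not necessarily what the script does):

import shutil
import huggingface_hub

REQUIRED_MODELS = {"gpt2": "main"}  # stand-in for the mapping defined earlier in the script
cache_dir = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE  # default cache location

# Free disk space in GB on the partition holding the cache.
free_memory = shutil.disk_usage(cache_dir).free * 1e-9
print(f"Free disk: {free_memory:.2f} GB")

# Evict the largest cached model repo that the CI does not need, all revisions at once.
cache_info = huggingface_hub.scan_cache_dir(cache_dir)
evictable = [
    repo
    for repo in cache_info.repos
    if repo.repo_type == "model" and repo.repo_id not in REQUIRED_MODELS
]
if evictable:
    largest = max(evictable, key=lambda repo: repo.size_on_disk)
    strategy = cache_info.delete_revisions(
        *[rev.commit_hash for rev in largest.revisions]
    )
    print(f"Freeing {strategy.expected_freed_size * 1e-9:.2f} GB from {largest.repo_id}")
    strategy.execute()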