mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
fix cache cleanup
This commit is contained in:
parent
04298e5799
commit
b44097a61b
@ -39,9 +39,11 @@ def cleanup_cache(token: str, cache_dir: str):
|
||||
# Retrieve the size per model for all models used in the CI.
|
||||
size_per_model = {}
|
||||
extension_per_model = {}
|
||||
checkpoints_per_model = {}
|
||||
for model_id, revision in REQUIRED_MODELS.items():
|
||||
print(f"Crawling {model_id}...")
|
||||
model_size = 0
|
||||
checkpoints = {}
|
||||
all_files = huggingface_hub.list_repo_files(
|
||||
model_id,
|
||||
repo_type="model",
|
||||
@ -68,11 +70,13 @@ def cleanup_cache(token: str, cache_dir: str):
|
||||
file_url, token=token
|
||||
)
|
||||
model_size += file_metadata.size * 1e-9 # in GB
|
||||
checkpoints[filename] = file_metadata.size * 1e-9
|
||||
|
||||
size_per_model[model_id] = model_size
|
||||
checkpoints_per_model[model_id] = checkpoints
|
||||
|
||||
total_required_size = sum(size_per_model.values())
|
||||
print(f"Total required disk: {total_required_size:.2f} GB")
|
||||
print(f"Total required disk for checkpoints: {total_required_size:.2f} GB")
|
||||
|
||||
cached_dir = huggingface_hub.scan_cache_dir(cache_dir)
|
||||
|
||||
@ -83,9 +87,20 @@ def cleanup_cache(token: str, cache_dir: str):
|
||||
# Retrieve the SHAs and model ids of other non-necessary models in the cache.
|
||||
for repo in cached_dir.repos:
|
||||
if repo.repo_id in REQUIRED_MODELS:
|
||||
cached_required_size_per_model[repo.repo_id] = (
|
||||
repo.size_on_disk * 1e-9
|
||||
) # in GB
|
||||
cached_required_size_per_model[repo.repo_id] = 0
|
||||
|
||||
for checkpoint in checkpoints_per_model[repo.repo_id]:
|
||||
filepath = huggingface_hub.try_to_load_from_cache(
|
||||
repo.repo_id,
|
||||
checkpoint,
|
||||
cache_dir=cache_dir,
|
||||
revision=REQUIRED_MODELS[repo.repo_id],
|
||||
)
|
||||
|
||||
if isinstance(filepath, str):
|
||||
cached_required_size_per_model[
|
||||
repo.repo_id
|
||||
] += checkpoints_per_model[repo.repo_id][checkpoint]
|
||||
elif repo.repo_type == "model":
|
||||
cache_size_per_model[repo.repo_id] = repo.size_on_disk * 1e-9 # in GB
|
||||
|
||||
@ -96,8 +111,13 @@ def cleanup_cache(token: str, cache_dir: str):
|
||||
|
||||
total_required_cached_size = sum(cached_required_size_per_model.values())
|
||||
total_other_cached_size = sum(cache_size_per_model.values())
|
||||
total_non_cached_required_size = total_required_size - total_required_cached_size
|
||||
|
||||
total_non_cached_required_size = total_required_size - total_required_cached_size
|
||||
assert total_non_cached_required_size >= 0
|
||||
|
||||
print(
|
||||
f"Total non-cached required models size: {total_non_cached_required_size:.2f} GB (to be downloaded)"
|
||||
)
|
||||
print(
|
||||
f"Total HF cached models size: {total_other_cached_size + total_required_cached_size:.2f} GB"
|
||||
)
|
||||
@ -113,7 +133,7 @@ def cleanup_cache(token: str, cache_dir: str):
|
||||
"Not enough space on device to execute the complete CI, please clean up the CI machine"
|
||||
)
|
||||
|
||||
while free_memory < total_non_cached_required_size * 1.05:
|
||||
while free_memory < 10 + total_non_cached_required_size * 1.05:
|
||||
if len(cache_size_per_model) == 0:
|
||||
raise ValueError("This should not happen.")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user