Another attempt.

This commit is contained in:
Nicolas Patry 2024-05-17 15:12:54 +00:00
parent 52c9ff9aca
commit d4b4c8d42e
2 changed files with 13 additions and 19 deletions

View File

@ -154,17 +154,6 @@ def download_weights(
try:
import json
# TODO remove this, there's currently logic to avoid detecting this file
# as being enough to say the entire repo is on disk, since we also need the parent model
# We're keeping this potential download to prevent breaking backward compatibilty,
# but it shouldn't be necessary in the flow here.
try:
medusa_head = hf_hub_download(
model_id, revision=revision, filename="medusa_lm_head.safetensors"
)
except Exception:
pass
config = hf_hub_download(
model_id, revision=revision, filename="config.json"
)
@ -176,17 +165,17 @@ def download_weights(
revision = "main"
try:
utils.weight_files(base_model_id, revision, extension)
logger.info(
f"Files for parent {base_model_id} are already present on the host. "
"Skipping download."
)
return
# Local files not found
except (
utils.LocalEntryNotFoundError,
FileNotFoundError,
utils.EntryNotFoundError,
):
logger.info(f"Downloading parent model {base_model_id}")
filenames = utils.weight_hub_files(
base_model_id, revision, extension
)
utils.download_weights(filenames, base_model_id, revision)
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
@ -235,6 +224,13 @@ def download_weights(
return
# Local files not found
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
try:
logger.info(f"Downloading parent model {base_model_id}")
filenames = utils.weight_hub_files(
base_model_id, revision, extension
)
utils.download_weights(filenames, base_model_id, revision)
except Exception:
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass

View File

@ -40,7 +40,6 @@ def _weight_hub_files_from_model_info(
and "arguments" not in s.rfilename
and "args" not in s.rfilename
and "training" not in s.rfilename
and "medusa_lm_head" not in s.rfilename
]
@ -57,7 +56,6 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
and "args" not in f
and "adapter" not in f
and "training" not in f
and "medusa_lm_head" not in f
]
return filenames