Another attempt.

This commit is contained in:
Nicolas Patry 2024-05-17 15:12:54 +00:00
parent 52c9ff9aca
commit d4b4c8d42e
2 changed files with 13 additions and 19 deletions

View File

@ -154,17 +154,6 @@ def download_weights(
try: try:
import json import json
# TODO remove this, there's currently logic to avoid detecting this file
# as being enough to say the entire repo is on disk, since we also need the parent model
# We're keeping this potential download to prevent breaking backward compatibilty,
# but it shouldn't be necessary in the flow here.
try:
medusa_head = hf_hub_download(
model_id, revision=revision, filename="medusa_lm_head.safetensors"
)
except Exception:
pass
config = hf_hub_download( config = hf_hub_download(
model_id, revision=revision, filename="config.json" model_id, revision=revision, filename="config.json"
) )
@ -176,17 +165,17 @@ def download_weights(
revision = "main" revision = "main"
try: try:
utils.weight_files(base_model_id, revision, extension) utils.weight_files(base_model_id, revision, extension)
logger.info(
f"Files for parent {base_model_id} are already present on the host. "
"Skipping download."
)
return
# Local files not found # Local files not found
except ( except (
utils.LocalEntryNotFoundError, utils.LocalEntryNotFoundError,
FileNotFoundError, FileNotFoundError,
utils.EntryNotFoundError, utils.EntryNotFoundError,
): ):
logger.info(f"Downloading parent model {base_model_id}")
filenames = utils.weight_hub_files(
base_model_id, revision, extension
)
utils.download_weights(filenames, base_model_id, revision)
pass pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass pass
@ -235,7 +224,14 @@ def download_weights(
return return
# Local files not found # Local files not found
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass try:
logger.info(f"Downloading parent model {base_model_id}")
filenames = utils.weight_hub_files(
base_model_id, revision, extension
)
utils.download_weights(filenames, base_model_id, revision)
except Exception:
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass pass

View File

@ -40,7 +40,6 @@ def _weight_hub_files_from_model_info(
and "arguments" not in s.rfilename and "arguments" not in s.rfilename
and "args" not in s.rfilename and "args" not in s.rfilename
and "training" not in s.rfilename and "training" not in s.rfilename
and "medusa_lm_head" not in s.rfilename
] ]
@ -57,7 +56,6 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
and "args" not in f and "args" not in f
and "adapter" not in f and "adapter" not in f
and "training" not in f and "training" not in f
and "medusa_lm_head" not in f
] ]
return filenames return filenames