Optional base_name_or_model_path.

This commit is contained in:
Nicolas Patry 2024-05-17 14:20:58 +00:00
parent e5416274df
commit 52c9ff9aca

View File

@ -171,22 +171,23 @@ def download_weights(
with open(config, "r") as f:
config = json.load(f)
model_id = config["base_model_name_or_path"]
revision = "main"
try:
utils.weight_files(model_id, revision, extension)
logger.info(
f"Files for parent {model_id} are already present on the host. "
"Skipping download."
)
return
# Local files not found
except (
utils.LocalEntryNotFoundError,
FileNotFoundError,
utils.EntryNotFoundError,
):
pass
base_model_id = config.get("base_model_name_or_path", None)
if base_model_id:
revision = "main"
try:
utils.weight_files(base_model_id, revision, extension)
logger.info(
f"Files for parent {base_model_id} are already present on the host. "
"Skipping download."
)
return
# Local files not found
except (
utils.LocalEntryNotFoundError,
FileNotFoundError,
utils.EntryNotFoundError,
):
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
@ -222,18 +223,19 @@ def download_weights(
with open(config, "r") as f:
config = json.load(f)
model_id = config["base_model_name_or_path"]
revision = "main"
try:
utils.weight_files(model_id, revision, extension)
logger.info(
f"Files for parent {model_id} are already present on the host. "
"Skipping download."
)
return
# Local files not found
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
base_model_id = config.get("base_model_name_or_path", None)
if base_model_id:
revision = "main"
try:
utils.weight_files(base_model_id, revision, extension)
logger.info(
f"Files for parent {base_model_id} are already present on the host. "
"Skipping download."
)
return
# Local files not found
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass