Optional base_name_or_model_path.

This commit is contained in:
Nicolas Patry 2024-05-17 14:20:58 +00:00
parent e5416274df
commit 52c9ff9aca

View File

@ -171,22 +171,23 @@ def download_weights(
with open(config, "r") as f: with open(config, "r") as f:
config = json.load(f) config = json.load(f)
model_id = config["base_model_name_or_path"] base_model_id = config.get("base_model_name_or_path", None)
revision = "main" if base_model_id:
try: revision = "main"
utils.weight_files(model_id, revision, extension) try:
logger.info( utils.weight_files(base_model_id, revision, extension)
f"Files for parent {model_id} are already present on the host. " logger.info(
"Skipping download." f"Files for parent {base_model_id} are already present on the host. "
) "Skipping download."
return )
# Local files not found return
except ( # Local files not found
utils.LocalEntryNotFoundError, except (
FileNotFoundError, utils.LocalEntryNotFoundError,
utils.EntryNotFoundError, FileNotFoundError,
): utils.EntryNotFoundError,
pass ):
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass pass
@ -222,18 +223,19 @@ def download_weights(
with open(config, "r") as f: with open(config, "r") as f:
config = json.load(f) config = json.load(f)
model_id = config["base_model_name_or_path"] base_model_id = config.get("base_model_name_or_path", None)
revision = "main" if base_model_id:
try: revision = "main"
utils.weight_files(model_id, revision, extension) try:
logger.info( utils.weight_files(base_model_id, revision, extension)
f"Files for parent {model_id} are already present on the host. " logger.info(
"Skipping download." f"Files for parent {base_model_id} are already present on the host. "
) "Skipping download."
return )
# Local files not found return
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): # Local files not found
pass except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass pass