Optional base_name_or_model_path.

This commit is contained in:
Nicolas Patry 2024-05-17 14:20:58 +00:00
parent e5416274df
commit 52c9ff9aca

View File

@ -171,12 +171,13 @@ def download_weights(
with open(config, "r") as f:
config = json.load(f)
model_id = config["base_model_name_or_path"]
base_model_id = config.get("base_model_name_or_path", None)
if base_model_id:
revision = "main"
try:
utils.weight_files(model_id, revision, extension)
utils.weight_files(base_model_id, revision, extension)
logger.info(
f"Files for parent {model_id} are already present on the host. "
f"Files for parent {base_model_id} are already present on the host. "
"Skipping download."
)
return
@ -222,12 +223,13 @@ def download_weights(
with open(config, "r") as f:
config = json.load(f)
model_id = config["base_model_name_or_path"]
base_model_id = config.get("base_model_name_or_path", None)
if base_model_id:
revision = "main"
try:
utils.weight_files(model_id, revision, extension)
utils.weight_files(base_model_id, revision, extension)
logger.info(
f"Files for parent {model_id} are already present on the host. "
f"Files for parent {base_model_id} are already present on the host. "
"Skipping download."
)
return