Very ugly code.

This commit is contained in:
Nicolas Patry 2024-05-17 15:59:23 +00:00
parent d4b4c8d42e
commit 6f3660de3b

View File

@ -161,21 +161,19 @@ def download_weights(
config = json.load(f)
base_model_id = config.get("base_model_name_or_path", None)
if base_model_id:
revision = "main"
if base_model_id and base_model_id != model_id:
try:
utils.weight_files(base_model_id, revision, extension)
# Local files not found
except (
utils.LocalEntryNotFoundError,
FileNotFoundError,
utils.EntryNotFoundError,
):
logger.info(f"Downloading parent model {base_model_id}")
filenames = utils.weight_hub_files(
base_model_id, revision, extension
download_weights(
model_id=base_model_id,
revision="main",
extension=extension,
auto_convert=auto_convert,
logger_level=logger_level,
json_output=json_output,
trust_remote_code=trust_remote_code,
)
utils.download_weights(filenames, base_model_id, revision)
except Exception:
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
@ -214,22 +212,17 @@ def download_weights(
base_model_id = config.get("base_model_name_or_path", None)
if base_model_id:
revision = "main"
try:
utils.weight_files(base_model_id, revision, extension)
logger.info(
f"Files for parent {base_model_id} are already present on the host. "
"Skipping download."
)
return
# Local files not found
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
try:
logger.info(f"Downloading parent model {base_model_id}")
filenames = utils.weight_hub_files(
base_model_id, revision, extension
download_weights(
model_id=base_model_id,
revision="main",
extension=extension,
auto_convert=auto_convert,
logger_level=logger_level,
json_output=json_output,
trust_remote_code=trust_remote_code,
)
utils.download_weights(filenames, base_model_id, revision)
except Exception:
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
@ -238,10 +231,13 @@ def download_weights(
# Try to see if there are local pytorch weights
try:
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
try:
local_pt_files = utils.weight_files(model_id, revision, ".bin")
except Exception:
local_pt_files = utils.weight_files(model_id, revision, ".pt")
# No local pytorch weights
except utils.LocalEntryNotFoundError:
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
if extension == ".safetensors":
logger.warning(
f"No safetensors weights found for model {model_id} at revision {revision}. "