Very ugly code.

This commit is contained in:
Nicolas Patry 2024-05-17 15:59:23 +00:00
parent d4b4c8d42e
commit 6f3660de3b

View File

@ -161,21 +161,19 @@ def download_weights(
config = json.load(f) config = json.load(f)
base_model_id = config.get("base_model_name_or_path", None) base_model_id = config.get("base_model_name_or_path", None)
if base_model_id: if base_model_id and base_model_id != model_id:
revision = "main"
try: try:
utils.weight_files(base_model_id, revision, extension)
# Local files not found
except (
utils.LocalEntryNotFoundError,
FileNotFoundError,
utils.EntryNotFoundError,
):
logger.info(f"Downloading parent model {base_model_id}") logger.info(f"Downloading parent model {base_model_id}")
filenames = utils.weight_hub_files( download_weights(
base_model_id, revision, extension model_id=base_model_id,
revision="main",
extension=extension,
auto_convert=auto_convert,
logger_level=logger_level,
json_output=json_output,
trust_remote_code=trust_remote_code,
) )
utils.download_weights(filenames, base_model_id, revision) except Exception:
pass pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass pass
@ -214,34 +212,32 @@ def download_weights(
base_model_id = config.get("base_model_name_or_path", None) base_model_id = config.get("base_model_name_or_path", None)
if base_model_id: if base_model_id:
revision = "main"
try: try:
utils.weight_files(base_model_id, revision, extension) logger.info(f"Downloading parent model {base_model_id}")
logger.info( download_weights(
f"Files for parent {base_model_id} are already present on the host. " model_id=base_model_id,
"Skipping download." revision="main",
extension=extension,
auto_convert=auto_convert,
logger_level=logger_level,
json_output=json_output,
trust_remote_code=trust_remote_code,
) )
return except Exception:
# Local files not found pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
try:
logger.info(f"Downloading parent model {base_model_id}")
filenames = utils.weight_hub_files(
base_model_id, revision, extension
)
utils.download_weights(filenames, base_model_id, revision)
except Exception:
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass pass
# Try to see if there are local pytorch weights # Try to see if there are local pytorch weights
try: try:
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
local_pt_files = utils.weight_files(model_id, revision, ".bin") try:
local_pt_files = utils.weight_files(model_id, revision, ".bin")
except Exception:
local_pt_files = utils.weight_files(model_id, revision, ".pt")
# No local pytorch weights # No local pytorch weights
except utils.LocalEntryNotFoundError: except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
if extension == ".safetensors": if extension == ".safetensors":
logger.warning( logger.warning(
f"No safetensors weights found for model {model_id} at revision {revision}. " f"No safetensors weights found for model {model_id} at revision {revision}. "