better comments

This commit is contained in:
OlivierDehaene 2023-05-03 09:55:17 +02:00
parent 89fc8b4812
commit 2b67bab02a

View File

@ -91,14 +91,12 @@ def download_weights(
except (utils.LocalEntryNotFoundError, FileNotFoundError):
pass
is_local_model = (
Path(model_id).exists()
and Path(model_id).is_dir()
or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
)
is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
"WEIGHTS_CACHE_OVERRIDE", None
) is not None
if not is_local_model:
# Download weights from the hub
# Try to download weights from the hub
try:
filenames = utils.weight_hub_files(model_id, revision, extension)
utils.download_weights(filenames, model_id, revision)
@ -115,11 +113,13 @@ def download_weights(
try:
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
local_pt_files = utils.weight_files(model_id, revision, ".bin")
# No local pytorch weights
except utils.LocalEntryNotFoundError:
if extension == ".safetensors":
logger.warning(
f"No safetensors weights found for model {model_id} at revision {revision}. "
f"Converting PyTorch weights instead."
f"Downloading PyTorch weights."
)
# Try to see if there are pytorch weights on the hub
@ -128,6 +128,11 @@ def download_weights(
local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
if auto_convert:
logger.warning(
f"No safetensors weights found for model {model_id} at revision {revision}. "
f"Converting PyTorch weights to safetensors."
)
# Safetensors final filenames
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"