better comments

This commit is contained in:
OlivierDehaene 2023-05-03 09:55:17 +02:00
parent 89fc8b4812
commit 2b67bab02a

View File

@ -91,14 +91,12 @@ def download_weights(
except (utils.LocalEntryNotFoundError, FileNotFoundError): except (utils.LocalEntryNotFoundError, FileNotFoundError):
pass pass
is_local_model = ( is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
Path(model_id).exists() "WEIGHTS_CACHE_OVERRIDE", None
and Path(model_id).is_dir() ) is not None
or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
)
if not is_local_model: if not is_local_model:
# Download weights from the hub # Try to download weights from the hub
try: try:
filenames = utils.weight_hub_files(model_id, revision, extension) filenames = utils.weight_hub_files(model_id, revision, extension)
utils.download_weights(filenames, model_id, revision) utils.download_weights(filenames, model_id, revision)
@ -115,11 +113,13 @@ def download_weights(
try: try:
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
local_pt_files = utils.weight_files(model_id, revision, ".bin") local_pt_files = utils.weight_files(model_id, revision, ".bin")
# No local pytorch weights
except utils.LocalEntryNotFoundError: except utils.LocalEntryNotFoundError:
if extension == ".safetensors": if extension == ".safetensors":
logger.warning( logger.warning(
f"No safetensors weights found for model {model_id} at revision {revision}. " f"No safetensors weights found for model {model_id} at revision {revision}. "
f"Converting PyTorch weights instead." f"Downloading PyTorch weights."
) )
# Try to see if there are pytorch weights on the hub # Try to see if there are pytorch weights on the hub
@ -128,6 +128,11 @@ def download_weights(
local_pt_files = utils.download_weights(pt_filenames, model_id, revision) local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
if auto_convert: if auto_convert:
logger.warning(
f"No safetensors weights found for model {model_id} at revision {revision}. "
f"Converting PyTorch weights to safetensors."
)
# Safetensors final filenames # Safetensors final filenames
local_st_files = [ local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"