mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
better comments
This commit is contained in:
parent
89fc8b4812
commit
2b67bab02a
@ -91,14 +91,12 @@ def download_weights(
|
||||
except (utils.LocalEntryNotFoundError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
is_local_model = (
|
||||
Path(model_id).exists()
|
||||
and Path(model_id).is_dir()
|
||||
or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
||||
)
|
||||
is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
|
||||
"WEIGHTS_CACHE_OVERRIDE", None
|
||||
) is not None
|
||||
|
||||
if not is_local_model:
|
||||
# Download weights from the hub
|
||||
# Try to download weights from the hub
|
||||
try:
|
||||
filenames = utils.weight_hub_files(model_id, revision, extension)
|
||||
utils.download_weights(filenames, model_id, revision)
|
||||
@ -115,11 +113,13 @@ def download_weights(
|
||||
try:
|
||||
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
|
||||
local_pt_files = utils.weight_files(model_id, revision, ".bin")
|
||||
|
||||
# No local pytorch weights
|
||||
except utils.LocalEntryNotFoundError:
|
||||
if extension == ".safetensors":
|
||||
logger.warning(
|
||||
f"No safetensors weights found for model {model_id} at revision {revision}. "
|
||||
f"Converting PyTorch weights instead."
|
||||
f"Downloading PyTorch weights."
|
||||
)
|
||||
|
||||
# Try to see if there are pytorch weights on the hub
|
||||
@ -128,6 +128,11 @@ def download_weights(
|
||||
local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
|
||||
|
||||
if auto_convert:
|
||||
logger.warning(
|
||||
f"No safetensors weights found for model {model_id} at revision {revision}. "
|
||||
f"Converting PyTorch weights to safetensors."
|
||||
)
|
||||
|
||||
# Safetensors final filenames
|
||||
local_st_files = [
|
||||
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
|
||||
|
Loading…
Reference in New Issue
Block a user