mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
better comments
This commit is contained in:
parent
89fc8b4812
commit
2b67bab02a
@ -91,14 +91,12 @@ def download_weights(
|
|||||||
except (utils.LocalEntryNotFoundError, FileNotFoundError):
|
except (utils.LocalEntryNotFoundError, FileNotFoundError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
is_local_model = (
|
is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
|
||||||
Path(model_id).exists()
|
"WEIGHTS_CACHE_OVERRIDE", None
|
||||||
and Path(model_id).is_dir()
|
) is not None
|
||||||
or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_local_model:
|
if not is_local_model:
|
||||||
# Download weights from the hub
|
# Try to download weights from the hub
|
||||||
try:
|
try:
|
||||||
filenames = utils.weight_hub_files(model_id, revision, extension)
|
filenames = utils.weight_hub_files(model_id, revision, extension)
|
||||||
utils.download_weights(filenames, model_id, revision)
|
utils.download_weights(filenames, model_id, revision)
|
||||||
@ -115,11 +113,13 @@ def download_weights(
|
|||||||
try:
|
try:
|
||||||
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
|
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
|
||||||
local_pt_files = utils.weight_files(model_id, revision, ".bin")
|
local_pt_files = utils.weight_files(model_id, revision, ".bin")
|
||||||
|
|
||||||
|
# No local pytorch weights
|
||||||
except utils.LocalEntryNotFoundError:
|
except utils.LocalEntryNotFoundError:
|
||||||
if extension == ".safetensors":
|
if extension == ".safetensors":
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"No safetensors weights found for model {model_id} at revision {revision}. "
|
f"No safetensors weights found for model {model_id} at revision {revision}. "
|
||||||
f"Converting PyTorch weights instead."
|
f"Downloading PyTorch weights."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try to see if there are pytorch weights on the hub
|
# Try to see if there are pytorch weights on the hub
|
||||||
@ -128,6 +128,11 @@ def download_weights(
|
|||||||
local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
|
local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
|
||||||
|
|
||||||
if auto_convert:
|
if auto_convert:
|
||||||
|
logger.warning(
|
||||||
|
f"No safetensors weights found for model {model_id} at revision {revision}. "
|
||||||
|
f"Converting PyTorch weights to safetensors."
|
||||||
|
)
|
||||||
|
|
||||||
# Safetensors final filenames
|
# Safetensors final filenames
|
||||||
local_st_files = [
|
local_st_files = [
|
||||||
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
|
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
|
||||||
|
Loading…
Reference in New Issue
Block a user