mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Another attempt.
This commit is contained in:
parent
52c9ff9aca
commit
d4b4c8d42e
@ -154,17 +154,6 @@ def download_weights(
|
|||||||
try:
|
try:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
# TODO remove this, there's currently logic to avoid detecting this file
|
|
||||||
# as being enough to say the entire repo is on disk, since we also need the parent model
|
|
||||||
# We're keeping this potential download to prevent breaking backward compatibilty,
|
|
||||||
# but it shouldn't be necessary in the flow here.
|
|
||||||
try:
|
|
||||||
medusa_head = hf_hub_download(
|
|
||||||
model_id, revision=revision, filename="medusa_lm_head.safetensors"
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
config = hf_hub_download(
|
config = hf_hub_download(
|
||||||
model_id, revision=revision, filename="config.json"
|
model_id, revision=revision, filename="config.json"
|
||||||
)
|
)
|
||||||
@ -176,17 +165,17 @@ def download_weights(
|
|||||||
revision = "main"
|
revision = "main"
|
||||||
try:
|
try:
|
||||||
utils.weight_files(base_model_id, revision, extension)
|
utils.weight_files(base_model_id, revision, extension)
|
||||||
logger.info(
|
|
||||||
f"Files for parent {base_model_id} are already present on the host. "
|
|
||||||
"Skipping download."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
# Local files not found
|
# Local files not found
|
||||||
except (
|
except (
|
||||||
utils.LocalEntryNotFoundError,
|
utils.LocalEntryNotFoundError,
|
||||||
FileNotFoundError,
|
FileNotFoundError,
|
||||||
utils.EntryNotFoundError,
|
utils.EntryNotFoundError,
|
||||||
):
|
):
|
||||||
|
logger.info(f"Downloading parent model {base_model_id}")
|
||||||
|
filenames = utils.weight_hub_files(
|
||||||
|
base_model_id, revision, extension
|
||||||
|
)
|
||||||
|
utils.download_weights(filenames, base_model_id, revision)
|
||||||
pass
|
pass
|
||||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||||
pass
|
pass
|
||||||
@ -235,7 +224,14 @@ def download_weights(
|
|||||||
return
|
return
|
||||||
# Local files not found
|
# Local files not found
|
||||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||||
pass
|
try:
|
||||||
|
logger.info(f"Downloading parent model {base_model_id}")
|
||||||
|
filenames = utils.weight_hub_files(
|
||||||
|
base_model_id, revision, extension
|
||||||
|
)
|
||||||
|
utils.download_weights(filenames, base_model_id, revision)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -40,7 +40,6 @@ def _weight_hub_files_from_model_info(
|
|||||||
and "arguments" not in s.rfilename
|
and "arguments" not in s.rfilename
|
||||||
and "args" not in s.rfilename
|
and "args" not in s.rfilename
|
||||||
and "training" not in s.rfilename
|
and "training" not in s.rfilename
|
||||||
and "medusa_lm_head" not in s.rfilename
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -57,7 +56,6 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
|
|||||||
and "args" not in f
|
and "args" not in f
|
||||||
and "adapter" not in f
|
and "adapter" not in f
|
||||||
and "training" not in f
|
and "training" not in f
|
||||||
and "medusa_lm_head" not in f
|
|
||||||
]
|
]
|
||||||
return filenames
|
return filenames
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user