mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Very ugly code.
This commit is contained in:
parent
d4b4c8d42e
commit
6f3660de3b
@ -161,21 +161,19 @@ def download_weights(
|
||||
config = json.load(f)
|
||||
|
||||
base_model_id = config.get("base_model_name_or_path", None)
|
||||
if base_model_id:
|
||||
revision = "main"
|
||||
if base_model_id and base_model_id != model_id:
|
||||
try:
|
||||
utils.weight_files(base_model_id, revision, extension)
|
||||
# Local files not found
|
||||
except (
|
||||
utils.LocalEntryNotFoundError,
|
||||
FileNotFoundError,
|
||||
utils.EntryNotFoundError,
|
||||
):
|
||||
logger.info(f"Downloading parent model {base_model_id}")
|
||||
filenames = utils.weight_hub_files(
|
||||
base_model_id, revision, extension
|
||||
download_weights(
|
||||
model_id=base_model_id,
|
||||
revision="main",
|
||||
extension=extension,
|
||||
auto_convert=auto_convert,
|
||||
logger_level=logger_level,
|
||||
json_output=json_output,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
utils.download_weights(filenames, base_model_id, revision)
|
||||
except Exception:
|
||||
pass
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
pass
|
||||
@ -214,22 +212,17 @@ def download_weights(
|
||||
|
||||
base_model_id = config.get("base_model_name_or_path", None)
|
||||
if base_model_id:
|
||||
revision = "main"
|
||||
try:
|
||||
utils.weight_files(base_model_id, revision, extension)
|
||||
logger.info(
|
||||
f"Files for parent {base_model_id} are already present on the host. "
|
||||
"Skipping download."
|
||||
)
|
||||
return
|
||||
# Local files not found
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
try:
|
||||
logger.info(f"Downloading parent model {base_model_id}")
|
||||
filenames = utils.weight_hub_files(
|
||||
base_model_id, revision, extension
|
||||
download_weights(
|
||||
model_id=base_model_id,
|
||||
revision="main",
|
||||
extension=extension,
|
||||
auto_convert=auto_convert,
|
||||
logger_level=logger_level,
|
||||
json_output=json_output,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
utils.download_weights(filenames, base_model_id, revision)
|
||||
except Exception:
|
||||
pass
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
@ -238,10 +231,13 @@ def download_weights(
|
||||
# Try to see if there are local pytorch weights
|
||||
try:
|
||||
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
|
||||
try:
|
||||
local_pt_files = utils.weight_files(model_id, revision, ".bin")
|
||||
except Exception:
|
||||
local_pt_files = utils.weight_files(model_id, revision, ".pt")
|
||||
|
||||
# No local pytorch weights
|
||||
except utils.LocalEntryNotFoundError:
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
if extension == ".safetensors":
|
||||
logger.warning(
|
||||
f"No safetensors weights found for model {model_id} at revision {revision}. "
|
||||
|
Loading…
Reference in New Issue
Block a user