mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Very ugly code.
This commit is contained in:
parent
d4b4c8d42e
commit
6f3660de3b
@ -161,21 +161,19 @@ def download_weights(
|
|||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
base_model_id = config.get("base_model_name_or_path", None)
|
base_model_id = config.get("base_model_name_or_path", None)
|
||||||
if base_model_id:
|
if base_model_id and base_model_id != model_id:
|
||||||
revision = "main"
|
|
||||||
try:
|
try:
|
||||||
utils.weight_files(base_model_id, revision, extension)
|
|
||||||
# Local files not found
|
|
||||||
except (
|
|
||||||
utils.LocalEntryNotFoundError,
|
|
||||||
FileNotFoundError,
|
|
||||||
utils.EntryNotFoundError,
|
|
||||||
):
|
|
||||||
logger.info(f"Downloading parent model {base_model_id}")
|
logger.info(f"Downloading parent model {base_model_id}")
|
||||||
filenames = utils.weight_hub_files(
|
download_weights(
|
||||||
base_model_id, revision, extension
|
model_id=base_model_id,
|
||||||
|
revision="main",
|
||||||
|
extension=extension,
|
||||||
|
auto_convert=auto_convert,
|
||||||
|
logger_level=logger_level,
|
||||||
|
json_output=json_output,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
utils.download_weights(filenames, base_model_id, revision)
|
except Exception:
|
||||||
pass
|
pass
|
||||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||||
pass
|
pass
|
||||||
@ -214,34 +212,32 @@ def download_weights(
|
|||||||
|
|
||||||
base_model_id = config.get("base_model_name_or_path", None)
|
base_model_id = config.get("base_model_name_or_path", None)
|
||||||
if base_model_id:
|
if base_model_id:
|
||||||
revision = "main"
|
|
||||||
try:
|
try:
|
||||||
utils.weight_files(base_model_id, revision, extension)
|
logger.info(f"Downloading parent model {base_model_id}")
|
||||||
logger.info(
|
download_weights(
|
||||||
f"Files for parent {base_model_id} are already present on the host. "
|
model_id=base_model_id,
|
||||||
"Skipping download."
|
revision="main",
|
||||||
|
extension=extension,
|
||||||
|
auto_convert=auto_convert,
|
||||||
|
logger_level=logger_level,
|
||||||
|
json_output=json_output,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
return
|
except Exception:
|
||||||
# Local files not found
|
pass
|
||||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
|
||||||
try:
|
|
||||||
logger.info(f"Downloading parent model {base_model_id}")
|
|
||||||
filenames = utils.weight_hub_files(
|
|
||||||
base_model_id, revision, extension
|
|
||||||
)
|
|
||||||
utils.download_weights(filenames, base_model_id, revision)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Try to see if there are local pytorch weights
|
# Try to see if there are local pytorch weights
|
||||||
try:
|
try:
|
||||||
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
|
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
|
||||||
local_pt_files = utils.weight_files(model_id, revision, ".bin")
|
try:
|
||||||
|
local_pt_files = utils.weight_files(model_id, revision, ".bin")
|
||||||
|
except Exception:
|
||||||
|
local_pt_files = utils.weight_files(model_id, revision, ".pt")
|
||||||
|
|
||||||
# No local pytorch weights
|
# No local pytorch weights
|
||||||
except utils.LocalEntryNotFoundError:
|
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||||
if extension == ".safetensors":
|
if extension == ".safetensors":
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"No safetensors weights found for model {model_id} at revision {revision}. "
|
f"No safetensors weights found for model {model_id} at revision {revision}. "
|
||||||
|
Loading…
Reference in New Issue
Block a user