Fixing the download strategy for ibm-fms

This commit is contained in:
Nicolas Patry 2024-05-17 10:33:00 +00:00
parent b5f1c9de06
commit e5416274df

View File

@ -154,13 +154,21 @@ def download_weights(
try:
import json
# TODO remove this, there's currently logic to avoid detecting this file
# as being enough to say the entire repo is on disk, since we also need the parent model
# We're keeping this potential download to prevent breaking backward compatibility,
# but it shouldn't be necessary in the flow here.
try:
medusa_head = hf_hub_download(
model_id, revision=revision, filename="medusa_lm_head.safetensors"
)
medusa_config = hf_hub_download(
except Exception:
pass
config = hf_hub_download(
model_id, revision=revision, filename="config.json"
)
with open(medusa_config, "r") as f:
with open(config, "r") as f:
config = json.load(f)
model_id = config["base_model_name_or_path"]
@ -195,14 +203,23 @@ def download_weights(
if not extension == ".safetensors" or not auto_convert:
raise e
elif (Path(model_id) / "medusa_lm_head.safetensors").exists():
elif (Path(model_id) / "adapter_config.json").exists():
# Try to load as a local PEFT model
try:
utils.download_and_unload_peft(
model_id, revision, trust_remote_code=trust_remote_code
)
utils.weight_files(model_id, revision, extension)
return
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
elif (Path(model_id) / "config.json").exists():
# Try to load as a local Medusa model
try:
import json
medusa_head = Path(model_id) / "medusa_lm_head.safetensors"
medusa_config = Path(model_id) / "config.json"
with open(medusa_config, "r") as f:
config = Path(model_id) / "config.json"
with open(config, "r") as f:
config = json.load(f)
model_id = config["base_model_name_or_path"]
@ -220,17 +237,6 @@ def download_weights(
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
elif (Path(model_id) / "adapter_config.json").exists():
# Try to load as a local PEFT model
try:
utils.download_and_unload_peft(
model_id, revision, trust_remote_code=trust_remote_code
)
utils.weight_files(model_id, revision, extension)
return
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
# Try to see if there are local pytorch weights
try:
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE