mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Fixing the download strategy for ibm-fms
This commit is contained in:
parent
b5f1c9de06
commit
e5416274df
@ -154,13 +154,21 @@ def download_weights(
|
||||
try:
|
||||
import json
|
||||
|
||||
medusa_head = hf_hub_download(
|
||||
model_id, revision=revision, filename="medusa_lm_head.safetensors"
|
||||
)
|
||||
medusa_config = hf_hub_download(
|
||||
# TODO remove this, there's currently logic to avoid detecting this file
|
||||
# as being enough to say the entire repo is on disk, since we also need the parent model
|
||||
# We're keeping this potential download to prevent breaking backward compatibility,
|
||||
# but it shouldn't be necessary in the flow here.
|
||||
try:
|
||||
medusa_head = hf_hub_download(
|
||||
model_id, revision=revision, filename="medusa_lm_head.safetensors"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
config = hf_hub_download(
|
||||
model_id, revision=revision, filename="config.json"
|
||||
)
|
||||
with open(medusa_config, "r") as f:
|
||||
with open(config, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
model_id = config["base_model_name_or_path"]
|
||||
@ -195,14 +203,23 @@ def download_weights(
|
||||
if not extension == ".safetensors" or not auto_convert:
|
||||
raise e
|
||||
|
||||
elif (Path(model_id) / "medusa_lm_head.safetensors").exists():
|
||||
elif (Path(model_id) / "adapter_config.json").exists():
|
||||
# Try to load as a local PEFT model
|
||||
try:
|
||||
utils.download_and_unload_peft(
|
||||
model_id, revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
utils.weight_files(model_id, revision, extension)
|
||||
return
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
pass
|
||||
elif (Path(model_id) / "config.json").exists():
|
||||
# Try to load as a local Medusa model
|
||||
try:
|
||||
import json
|
||||
|
||||
medusa_head = Path(model_id) / "medusa_lm_head.safetensors"
|
||||
medusa_config = Path(model_id) / "config.json"
|
||||
with open(medusa_config, "r") as f:
|
||||
config = Path(model_id) / "config.json"
|
||||
with open(config, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
model_id = config["base_model_name_or_path"]
|
||||
@ -220,17 +237,6 @@ def download_weights(
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
pass
|
||||
|
||||
elif (Path(model_id) / "adapter_config.json").exists():
|
||||
# Try to load as a local PEFT model
|
||||
try:
|
||||
utils.download_and_unload_peft(
|
||||
model_id, revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
utils.weight_files(model_id, revision, extension)
|
||||
return
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
pass
|
||||
|
||||
# Try to see if there are local pytorch weights
|
||||
try:
|
||||
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
|
||||
|
Loading…
Reference in New Issue
Block a user