Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Fixing the download strategy for ibm-fms

commit e5416274df (parent b5f1c9de06)

In short: the hub path now treats the medusa_lm_head.safetensors fetch as optional and still resolves the parent model from config.json's base_model_name_or_path, and the local path probes for a PEFT adapter_config.json before treating a directory as a Medusa repository.
@@ -154,13 +154,21 @@ def download_weights(
         try:
             import json
 
-            medusa_head = hf_hub_download(
-                model_id, revision=revision, filename="medusa_lm_head.safetensors"
-            )
-            medusa_config = hf_hub_download(
+            # TODO remove this, there's currently logic to avoid detecting this file
+            # as being enough to say the entire repo is on disk, since we also need the parent model
+            # We're keeping this potential download to prevent breaking backward compatibility,
+            # but it shouldn't be necessary in the flow here.
+            try:
+                medusa_head = hf_hub_download(
+                    model_id, revision=revision, filename="medusa_lm_head.safetensors"
+                )
+            except Exception:
+                pass
+
+            config = hf_hub_download(
                 model_id, revision=revision, filename="config.json"
             )
-            with open(medusa_config, "r") as f:
+            with open(config, "r") as f:
                 config = json.load(f)
 
             model_id = config["base_model_name_or_path"]
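Read in isolation, the right-hand side of this hunk is a two-step hub resolution: try to fetch the Medusa head but tolerate its absence, then fetch config.json and follow its base_model_name_or_path pointer. A minimal sketch of that flow, assuming only that huggingface_hub is installed; resolve_base_model is a hypothetical wrapper, not a function from the commit:

    import json
    from typing import Optional

    from huggingface_hub import hf_hub_download


    def resolve_base_model(model_id: str, revision: Optional[str] = None) -> str:
        # The Medusa head is optional: its absence must not abort the
        # download, because the parent model still has to be fetched.
        try:
            hf_hub_download(
                model_id, revision=revision, filename="medusa_lm_head.safetensors"
            )
        except Exception:
            pass

        # config.json carries the pointer to the parent model.
        config_path = hf_hub_download(
            model_id, revision=revision, filename="config.json"
        )
        with open(config_path, "r") as f:
            return json.load(f)["base_model_name_or_path"]

The nested try appears to be the ibm-fms fix itself: a repository without the head file previously bailed out of this block before config.json was read, so the parent model it points at was never downloaded; tolerating the failure keeps the base_model_name_or_path resolution alive.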
@@ -195,14 +203,23 @@ def download_weights(
             if not extension == ".safetensors" or not auto_convert:
                 raise e
 
-    elif (Path(model_id) / "medusa_lm_head.safetensors").exists():
+    elif (Path(model_id) / "adapter_config.json").exists():
+        # Try to load as a local PEFT model
+        try:
+            utils.download_and_unload_peft(
+                model_id, revision, trust_remote_code=trust_remote_code
+            )
+            utils.weight_files(model_id, revision, extension)
+            return
+        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+            pass
+    elif (Path(model_id) / "config.json").exists():
         # Try to load as a local Medusa model
         try:
             import json
 
-            medusa_head = Path(model_id) / "medusa_lm_head.safetensors"
-            medusa_config = Path(model_id) / "config.json"
-            with open(medusa_config, "r") as f:
+            config = Path(model_id) / "config.json"
+            with open(config, "r") as f:
                 config = json.load(f)
 
             model_id = config["base_model_name_or_path"]
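This hunk reorders local-directory detection: adapter_config.json (a PEFT adapter) is now checked before config.json, and the Medusa branch no longer requires medusa_lm_head.safetensors to be present. A small sketch of the resulting precedence; detect_local_layout is a hypothetical illustration, not code from the repository:

    import json
    from pathlib import Path


    def detect_local_layout(model_id: str) -> str:
        root = Path(model_id)
        # PEFT adapters are probed first; presumably an adapter directory
        # can also ship a config.json, which would otherwise match the
        # Medusa branch below.
        if (root / "adapter_config.json").exists():
            return "peft"  # handled by utils.download_and_unload_peft
        if (root / "config.json").exists():
            config = json.loads((root / "config.json").read_text())
            # Only Medusa-style configs point at a parent model.
            if "base_model_name_or_path" in config:
                return "medusa"
        return "plain"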
@@ -220,17 +237,6 @@ def download_weights(
         except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
             pass
 
-    elif (Path(model_id) / "adapter_config.json").exists():
-        # Try to load as a local PEFT model
-        try:
-            utils.download_and_unload_peft(
-                model_id, revision, trust_remote_code=trust_remote_code
-            )
-            utils.weight_files(model_id, revision, extension)
-            return
-        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
-            pass
-
     # Try to see if there are local pytorch weights
     try:
         # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE