mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Small fixes in the weights loading logic.
This commit is contained in:
parent
915e5f088c
commit
e69e68c8ea
@ -194,16 +194,12 @@ def download_weights(
|
||||
if not extension == ".safetensors" or not auto_convert:
|
||||
raise e
|
||||
|
||||
elif (Path(model_id) / "medusa_lm_head.pt").exists():
|
||||
elif (Path(model_id) / "medusa_lm_head.safetensors").exists():
|
||||
# Try to load as a local Medusa model
|
||||
try:
|
||||
import json
|
||||
|
||||
medusa_head = Path(model_id) / "medusa_lm_head.pt"
|
||||
if auto_convert:
|
||||
medusa_sf = Path(model_id) / "medusa_lm_head.safetensors"
|
||||
if not medusa_sf.exists():
|
||||
utils.convert_files([Path(medusa_head)], [medusa_sf], [])
|
||||
medusa_head = Path(model_id) / "medusa_lm_head.safetensors"
|
||||
medusa_config = Path(model_id) / "config.json"
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
|
@ -40,6 +40,7 @@ def _weight_hub_files_from_model_info(
|
||||
and "arguments" not in s.rfilename
|
||||
and "args" not in s.rfilename
|
||||
and "training" not in s.rfilename
|
||||
and "medusa_lm_head" not in s.rfilename
|
||||
]
|
||||
|
||||
|
||||
@ -56,6 +57,7 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
|
||||
and "args" not in f
|
||||
and "adapter" not in f
|
||||
and "training" not in f
|
||||
and "medusa_lm_head" not in f
|
||||
]
|
||||
return filenames
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user