Small fixes in the weights loading logic.

This commit is contained in:
Nicolas Patry 2024-02-26 17:32:42 +00:00
parent 915e5f088c
commit e69e68c8ea
2 changed files with 4 additions and 6 deletions

View File

@@ -194,16 +194,12 @@ def download_weights(
if not extension == ".safetensors" or not auto_convert: if not extension == ".safetensors" or not auto_convert:
raise e raise e
elif (Path(model_id) / "medusa_lm_head.pt").exists(): elif (Path(model_id) / "medusa_lm_head.safetensors").exists():
# Try to load as a local Medusa model # Try to load as a local Medusa model
try: try:
import json import json
medusa_head = Path(model_id) / "medusa_lm_head.pt" medusa_head = Path(model_id) / "medusa_lm_head.safetensors"
if auto_convert:
medusa_sf = Path(model_id) / "medusa_lm_head.safetensors"
if not medusa_sf.exists():
utils.convert_files([Path(medusa_head)], [medusa_sf], [])
medusa_config = Path(model_id) / "config.json" medusa_config = Path(model_id) / "config.json"
with open(medusa_config, "r") as f: with open(medusa_config, "r") as f:
config = json.load(f) config = json.load(f)

View File

@@ -40,6 +40,7 @@ def _weight_hub_files_from_model_info(
and "arguments" not in s.rfilename and "arguments" not in s.rfilename
and "args" not in s.rfilename and "args" not in s.rfilename
and "training" not in s.rfilename and "training" not in s.rfilename
and "medusa_lm_head" not in s.rfilename
] ]
@@ -56,6 +57,7 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
and "args" not in f and "args" not in f
and "adapter" not in f and "adapter" not in f
and "training" not in f and "training" not in f
and "medusa_lm_head" not in f
] ]
return filenames return filenames