Keep code style consistent with PR #1419

This commit is contained in:
PYNing 2024-01-10 10:58:56 +08:00
parent 6545383861
commit 19aa5308cb
2 changed files with 6 additions and 6 deletions


@@ -198,17 +198,17 @@ def download_weights(
             if not extension == ".safetensors" or not auto_convert:
                 raise e
-    elif Path(model_id).joinpath("medusa_lm_head.pt").exists():
+    elif (Path(model_id) / "medusa_lm_head.pt").exists():
         # Try to load as a local Medusa model
         try:
             import json
-            medusa_head = Path(model_id).joinpath("medusa_lm_head.pt")
+            medusa_head = Path(model_id) / "medusa_lm_head.pt"
             if auto_convert:
-                medusa_sf = Path(model_id).joinpath("medusa_lm_head.safetensors")
+                medusa_sf = Path(model_id) / "medusa_lm_head.safetensors"
                 if not medusa_sf.exists():
                     utils.convert_files([Path(medusa_head)], [medusa_sf], [])
-            medusa_config = Path(model_id).joinpath("config.json")
+            medusa_config = Path(model_id) / "config.json"
             with open(medusa_config, "r") as f:
                 config = json.load(f)


@@ -86,8 +86,8 @@ class FlashLlama(FlashCausalLM):
                     use_medusa, revision=revision, filename="medusa_lm_head.pt"
                 )
             else:
-                medusa_config = str(Path(use_medusa).joinpath("config.json"))
-                medusa_head = str(Path(use_medusa).joinpath("medusa_lm_head.pt"))
+                medusa_config = str(Path(use_medusa) / "config.json")
+                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
             with open(medusa_config, "r") as f:
                 config = json.load(f)
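
Note on the style change: `pathlib.Path` overloads the `/` operator, so `Path(x) / "name"` builds the same path object as `Path(x).joinpath("name")`; the change is purely stylistic. A minimal sketch of the equivalence (the directory used here is a hypothetical example, not taken from the repository):

```python
from pathlib import Path

model_id = "/data/my-medusa-model"  # hypothetical local model directory

# Both expressions produce the same Path object; "/" is the shorter
# spelling adopted in PR #1419 and kept consistent by this commit.
assert Path(model_id) / "medusa_lm_head.pt" == Path(model_id).joinpath("medusa_lm_head.pt")

# str() yields the plain string form for APIs that expect str paths,
# as in the FlashLlama branch above.
print(str(Path(model_id) / "config.json"))  # /data/my-medusa-model/config.json
```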