mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Keep code style consistent with PR #1419
This commit is contained in:
parent
6545383861
commit
19aa5308cb
@ -198,17 +198,17 @@ def download_weights(
|
||||
if not extension == ".safetensors" or not auto_convert:
|
||||
raise e
|
||||
|
||||
elif Path(model_id).joinpath("medusa_lm_head.pt").exists():
|
||||
elif (Path(model_id) / "medusa_lm_head.pt").exists():
|
||||
# Try to load as a local Medusa model
|
||||
try:
|
||||
import json
|
||||
|
||||
medusa_head = Path(model_id).joinpath("medusa_lm_head.pt")
|
||||
medusa_head = Path(model_id) / "medusa_lm_head.pt"
|
||||
if auto_convert:
|
||||
medusa_sf = Path(model_id).joinpath("medusa_lm_head.safetensors")
|
||||
medusa_sf = Path(model_id) / "medusa_lm_head.safetensors"
|
||||
if not medusa_sf.exists():
|
||||
utils.convert_files([Path(medusa_head)], [medusa_sf], [])
|
||||
medusa_config = Path(model_id).joinpath("config.json")
|
||||
medusa_config = Path(model_id) / "config.json"
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
|
@ -86,8 +86,8 @@ class FlashLlama(FlashCausalLM):
|
||||
use_medusa, revision=revision, filename="medusa_lm_head.pt"
|
||||
)
|
||||
else:
|
||||
medusa_config = str(Path(use_medusa).joinpath("config.json"))
|
||||
medusa_head = str(Path(use_medusa).joinpath("medusa_lm_head.pt"))
|
||||
medusa_config = str(Path(use_medusa) / "config.json")
|
||||
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
|
Loading…
Reference in New Issue
Block a user