mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Kepp code style consistent with PR #1419
This commit is contained in:
parent
6545383861
commit
19aa5308cb
@ -198,17 +198,17 @@ def download_weights(
|
|||||||
if not extension == ".safetensors" or not auto_convert:
|
if not extension == ".safetensors" or not auto_convert:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
elif Path(model_id).joinpath("medusa_lm_head.pt").exists():
|
elif (Path(model_id) / "medusa_lm_head.pt").exists():
|
||||||
# Try to load as a local Medusa model
|
# Try to load as a local Medusa model
|
||||||
try:
|
try:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
medusa_head = Path(model_id).joinpath("medusa_lm_head.pt")
|
medusa_head = Path(model_id) / "medusa_lm_head.pt"
|
||||||
if auto_convert:
|
if auto_convert:
|
||||||
medusa_sf = Path(model_id).joinpath("medusa_lm_head.safetensors")
|
medusa_sf = Path(model_id) / "medusa_lm_head.safetensors"
|
||||||
if not medusa_sf.exists():
|
if not medusa_sf.exists():
|
||||||
utils.convert_files([Path(medusa_head)], [medusa_sf], [])
|
utils.convert_files([Path(medusa_head)], [medusa_sf], [])
|
||||||
medusa_config = Path(model_id).joinpath("config.json")
|
medusa_config = Path(model_id) / "config.json"
|
||||||
with open(medusa_config, "r") as f:
|
with open(medusa_config, "r") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
|
@ -86,8 +86,8 @@ class FlashLlama(FlashCausalLM):
|
|||||||
use_medusa, revision=revision, filename="medusa_lm_head.pt"
|
use_medusa, revision=revision, filename="medusa_lm_head.pt"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
medusa_config = str(Path(use_medusa).joinpath("config.json"))
|
medusa_config = str(Path(use_medusa) / "config.json")
|
||||||
medusa_head = str(Path(use_medusa).joinpath("medusa_lm_head.pt"))
|
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
|
||||||
|
|
||||||
with open(medusa_config, "r") as f:
|
with open(medusa_config, "r") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
Loading…
Reference in New Issue
Block a user