Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
Fix local load for Medusa
This commit is contained in:
Parent: 630800eed3
Commit: 7ffe9023da
@ -198,6 +198,35 @@ def download_weights(
|
||||
if not extension == ".safetensors" or not auto_convert:
|
||||
raise e
|
||||
|
||||
elif Path(model_id).joinpath("medusa_lm_head.pt").exists():
|
||||
# Try to load as a local Medusa model
|
||||
try:
|
||||
import json
|
||||
|
||||
medusa_head = Path(model_id).joinpath("medusa_lm_head.pt")
|
||||
if auto_convert:
|
||||
medusa_sf = Path(model_id).joinpath("medusa_lm_head.safetensors")
|
||||
if not medusa_sf.exists():
|
||||
utils.convert_files([Path(medusa_head)], [medusa_sf], [])
|
||||
medusa_config = Path(model_id).joinpath("config.json")
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
model_id = config["base_model_name_or_path"]
|
||||
revision = "main"
|
||||
try:
|
||||
utils.weight_files(model_id, revision, extension)
|
||||
logger.info(
|
||||
f"Files for parent {model_id} are already present on the host. "
|
||||
"Skipping download."
|
||||
)
|
||||
return
|
||||
# Local files not found
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
pass
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
pass
|
||||
|
||||
else:
|
||||
# Try to load as a local PEFT model
|
||||
try:
|
||||
|
@ -71,15 +71,26 @@ class FlashLlama(FlashCausalLM):
|
||||
from text_generation_server.utils.medusa import MedusaModel
|
||||
from huggingface_hub import hf_hub_download
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
|
||||
"WEIGHTS_CACHE_OVERRIDE", None
|
||||
) is not None
|
||||
|
||||
if not is_local_model:
|
||||
medusa_config = hf_hub_download(
|
||||
use_medusa, revision=revision, filename="config.json"
|
||||
)
|
||||
medusa_head = hf_hub_download(
|
||||
use_medusa, revision=revision, filename="medusa_lm_head.pt"
|
||||
)
|
||||
else:
|
||||
medusa_config = str(Path(use_medusa).joinpath("config.json"))
|
||||
medusa_head = str(Path(use_medusa).joinpath("medusa_lm_head.pt"))
|
||||
|
||||
medusa_config = hf_hub_download(
|
||||
use_medusa, revision=revision, filename="config.json"
|
||||
)
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
medusa_head = hf_hub_download(
|
||||
use_medusa, revision=revision, filename="medusa_lm_head.pt"
|
||||
)
|
||||
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
|
||||
weights = Weights(
|
||||
[medusa_sf], device, dtype, process_group=self.process_group
|
||||
|
Loading…
Reference in New Issue
Block a user