Fix local load for Medusa

PYNing 2024-01-09 19:42:03 +08:00
parent 630800eed3
commit 7ffe9023da
2 changed files with 47 additions and 7 deletions


@@ -198,6 +198,35 @@ def download_weights(
            if not extension == ".safetensors" or not auto_convert:
                raise e

    elif Path(model_id).joinpath("medusa_lm_head.pt").exists():
        # Try to load as a local Medusa model
        try:
            import json

            medusa_head = Path(model_id).joinpath("medusa_lm_head.pt")
            if auto_convert:
                medusa_sf = Path(model_id).joinpath("medusa_lm_head.safetensors")
                if not medusa_sf.exists():
                    utils.convert_files([Path(medusa_head)], [medusa_sf], [])
            medusa_config = Path(model_id).joinpath("config.json")
            with open(medusa_config, "r") as f:
                config = json.load(f)

            model_id = config["base_model_name_or_path"]
            revision = "main"
            try:
                utils.weight_files(model_id, revision, extension)
                logger.info(
                    f"Files for parent {model_id} are already present on the host. "
                    "Skipping download."
                )
                return
            # Local files not found
            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
                pass
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass
    else:
        # Try to load as a local PEFT model
        try:

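In plain terms, the new elif branch treats a local directory containing medusa_lm_head.pt as a Medusa head checkout: it converts the head to safetensors if requested, then reads config.json to find the base model whose weights still need to be fetched. A minimal standalone sketch of that resolution step (the helper name resolve_local_medusa is hypothetical; in TGI the conversion itself is done by utils.convert_files):

import json
from pathlib import Path


def resolve_local_medusa(model_dir: str) -> str:
    """Return the base model that a local Medusa head directory points at."""
    medusa_dir = Path(model_dir)
    if not medusa_dir.joinpath("medusa_lm_head.pt").exists():
        raise FileNotFoundError(f"{model_dir} does not contain medusa_lm_head.pt")
    # A Medusa head's config.json records its parent under base_model_name_or_path
    with open(medusa_dir.joinpath("config.json"), "r") as f:
        config = json.load(f)
    return config["base_model_name_or_path"]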

@@ -71,15 +71,26 @@ class FlashLlama(FlashCausalLM):
            from text_generation_server.utils.medusa import MedusaModel
            from huggingface_hub import hf_hub_download
            import json
            import os
            from pathlib import Path

            is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
                "WEIGHTS_CACHE_OVERRIDE", None
            ) is not None

            if not is_local_model:
                medusa_config = hf_hub_download(
                    use_medusa, revision=revision, filename="config.json"
                )
                medusa_head = hf_hub_download(
                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
                )
            else:
                medusa_config = str(Path(use_medusa).joinpath("config.json"))
                medusa_head = str(Path(use_medusa).joinpath("medusa_lm_head.pt"))

            with open(medusa_config, "r") as f:
                config = json.load(f)

            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
            weights = Weights(
                [medusa_sf], device, dtype, process_group=self.process_group
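Note that Weights is built from medusa_sf, the .safetensors sibling of medusa_lm_head.pt, so a local Medusa directory must end up with that file; the download_weights hunk above produces it via utils.convert_files. A rough standalone sketch of that conversion, assuming the head checkpoint is a flat state dict of tensors (the path is illustrative, and TGI's utils.convert_files handles more cases than this):

from pathlib import Path

import torch
from safetensors.torch import save_file

pt_path = Path("/data/medusa-head/medusa_lm_head.pt")  # illustrative local path
sf_path = pt_path.with_suffix(".safetensors")

if not sf_path.exists():
    # Load the PyTorch checkpoint on CPU and rewrite it as safetensors
    state_dict = torch.load(pt_path, map_location="cpu")
    save_file({k: v.contiguous() for k, v in state_dict.items()}, str(sf_path))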