Mirror of https://github.com/huggingface/text-generation-inference.git
Fix local load for Medusa

commit 7ffe9023da (parent 630800eed3)
@@ -198,6 +198,35 @@ def download_weights(
             if not extension == ".safetensors" or not auto_convert:
                 raise e

+    elif Path(model_id).joinpath("medusa_lm_head.pt").exists():
+        # Try to load as a local Medusa model
+        try:
+            import json
+
+            medusa_head = Path(model_id).joinpath("medusa_lm_head.pt")
+            if auto_convert:
+                medusa_sf = Path(model_id).joinpath("medusa_lm_head.safetensors")
+                if not medusa_sf.exists():
+                    utils.convert_files([Path(medusa_head)], [medusa_sf], [])
+            medusa_config = Path(model_id).joinpath("config.json")
+            with open(medusa_config, "r") as f:
+                config = json.load(f)
+
+            model_id = config["base_model_name_or_path"]
+            revision = "main"
+            try:
+                utils.weight_files(model_id, revision, extension)
+                logger.info(
+                    f"Files for parent {model_id} are already present on the host. "
+                    "Skipping download."
+                )
+                return
+            # Local files not found
+            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+                pass
+        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+            pass
+
     else:
         # Try to load as a local PEFT model
         try:
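The hunk above lets download_weights recognize a local directory that contains a Medusa speculative-decoding head: it converts medusa_lm_head.pt to safetensors when auto_convert is set, then reads base_model_name_or_path from the directory's config.json and re-targets the download at that parent model. A minimal standalone sketch of the resolution flow, assuming a hypothetical directory path and helper name (neither is part of the commit):

import json
from pathlib import Path


def resolve_medusa_parent(model_id: str) -> str:
    # If model_id is a local Medusa directory, return the parent ("base")
    # model whose weights actually need downloading; otherwise pass through.
    medusa_dir = Path(model_id)
    if not (medusa_dir / "medusa_lm_head.pt").exists():
        return model_id
    with open(medusa_dir / "config.json", "r") as f:
        config = json.load(f)
    return config["base_model_name_or_path"]


print(resolve_medusa_parent("/data/medusa-head"))  # hypothetical path

In practice this means pointing the weight-download command at such a directory, e.g. something like text-generation-server download-weights /data/medusa-head (the path is hypothetical), downloads only the parent model's weights, since the head itself is already on disk.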
@@ -71,15 +71,26 @@ class FlashLlama(FlashCausalLM):
             from text_generation_server.utils.medusa import MedusaModel
             from huggingface_hub import hf_hub_download
             import json
+            import os
+            from pathlib import Path

-            medusa_config = hf_hub_download(
-                use_medusa, revision=revision, filename="config.json"
-            )
+            is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
+                "WEIGHTS_CACHE_OVERRIDE", None
+            ) is not None
+
+            if not is_local_model:
+                medusa_config = hf_hub_download(
+                    use_medusa, revision=revision, filename="config.json"
+                )
+                medusa_head = hf_hub_download(
+                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
+                )
+            else:
+                medusa_config = str(Path(use_medusa).joinpath("config.json"))
+                medusa_head = str(Path(use_medusa).joinpath("medusa_lm_head.pt"))
+
             with open(medusa_config, "r") as f:
                 config = json.load(f)
-            medusa_head = hf_hub_download(
-                use_medusa, revision=revision, filename="medusa_lm_head.pt"
-            )
             medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
             weights = Weights(
                 [medusa_sf], device, dtype, process_group=self.process_group
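FlashLlama mirrors this on the loading side: when use_medusa points at an existing local directory, or when WEIGHTS_CACHE_OVERRIDE is set, the config and head paths are built directly from the directory instead of going through hf_hub_download. A minimal sketch of just that detection logic, with a hypothetical use_medusa value:

import os
from pathlib import Path

use_medusa = "/data/medusa-head"  # hypothetical local directory

# Same check the hunk adds: a real local directory, or an explicit cache override.
is_local_model = (
    Path(use_medusa).exists() and Path(use_medusa).is_dir()
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None

if is_local_model:
    medusa_config = str(Path(use_medusa).joinpath("config.json"))
    medusa_head = str(Path(use_medusa).joinpath("medusa_lm_head.pt"))
    # The loader then swaps the .pt suffix for the converted safetensors file.
    medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
    print(medusa_config, medusa_sf)

Note that in the local case the final medusa_sf path only exists because the download_weights branch in the first hunk already converted medusa_lm_head.pt to medusa_lm_head.safetensors in the same directory.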