diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 4d6c5603..a513f5e6 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -194,16 +194,12 @@ def download_weights( if not extension == ".safetensors" or not auto_convert: raise e - elif (Path(model_id) / "medusa_lm_head.pt").exists(): + elif (Path(model_id) / "medusa_lm_head.safetensors").exists(): # Try to load as a local Medusa model try: import json - medusa_head = Path(model_id) / "medusa_lm_head.pt" - if auto_convert: - medusa_sf = Path(model_id) / "medusa_lm_head.safetensors" - if not medusa_sf.exists(): - utils.convert_files([Path(medusa_head)], [medusa_sf], []) + medusa_head = Path(model_id) / "medusa_lm_head.safetensors" medusa_config = Path(model_id) / "config.json" with open(medusa_config, "r") as f: config = json.load(f) diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index b56484f6..a81e659d 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -40,6 +40,7 @@ def _weight_hub_files_from_model_info( and "arguments" not in s.rfilename and "args" not in s.rfilename and "training" not in s.rfilename + and "medusa_lm_head" not in s.rfilename ] @@ -56,6 +57,7 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]: and "args" not in f and "adapter" not in f and "training" not in f + and "medusa_lm_head" not in f ] return filenames