diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 6b7a894c..e11c7cf9 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -306,6 +306,7 @@ def launcher(event_loop): use_flash_attention: bool = True, disable_grammar_support: bool = False, dtype: Optional[str] = None, + revision: Optional[str] = None, ): port = random.randint(8000, 10_000) @@ -321,6 +322,9 @@ def launcher(event_loop): if dtype is not None: args.append("--dtype") args.append(dtype) + if revision is not None: + args.append("--revision") + args.append(revision) if trust_remote_code: args.append("--trust-remote-code") diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index b74fbe36..4d6c5603 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -154,12 +154,8 @@ def download_weights( import json medusa_head = hf_hub_download( - model_id, revision=revision, filename="medusa_lm_head.pt" + model_id, revision=revision, filename="medusa_lm_head.safetensors" ) - if auto_convert: - medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors") - if not medusa_sf.exists(): - utils.convert_files([Path(medusa_head)], [medusa_sf], []) medusa_config = hf_hub_download( model_id, revision=revision, filename="config.json" )