Fixing revision for the medusa test.

This commit is contained in:
Nicolas Patry 2024-02-26 16:31:40 +00:00
parent e672f976fb
commit bfec09ecc2

View File

@ -124,6 +124,7 @@ def get_model(
use_medusa = None
if "medusa_num_heads" in config_dict:
medusa_model_id = model_id
medusa_revision = revision
model_id = config_dict["base_model_name_or_path"]
revision = "main"
speculate_medusa = config_dict["medusa_num_heads"]
@ -143,11 +144,11 @@ def get_model(
is_local = Path(medusa_model_id).exists()
if not is_local:
medusa_config = hf_hub_download(
medusa_model_id, revision=revision, filename="config.json"
medusa_model_id, revision=medusa_revision, filename="config.json"
)
hf_hub_download(
medusa_model_id,
revision=revision,
revision=medusa_revision,
filename="medusa_lm_head.safetensors",
)
use_medusa = Path(medusa_config).parent