Fixing revision for the medusa test.

This commit is contained in:
Nicolas Patry 2024-02-26 16:31:40 +00:00
parent e672f976fb
commit bfec09ecc2

View File

@ -124,6 +124,7 @@ def get_model(
use_medusa = None use_medusa = None
if "medusa_num_heads" in config_dict: if "medusa_num_heads" in config_dict:
medusa_model_id = model_id medusa_model_id = model_id
medusa_revision = revision
model_id = config_dict["base_model_name_or_path"] model_id = config_dict["base_model_name_or_path"]
revision = "main" revision = "main"
speculate_medusa = config_dict["medusa_num_heads"] speculate_medusa = config_dict["medusa_num_heads"]
@ -143,11 +144,11 @@ def get_model(
is_local = Path(medusa_model_id).exists() is_local = Path(medusa_model_id).exists()
if not is_local: if not is_local:
medusa_config = hf_hub_download( medusa_config = hf_hub_download(
medusa_model_id, revision=revision, filename="config.json" medusa_model_id, revision=medusa_revision, filename="config.json"
) )
hf_hub_download( hf_hub_download(
medusa_model_id, medusa_model_id,
revision=revision, revision=medusa_revision,
filename="medusa_lm_head.safetensors", filename="medusa_lm_head.safetensors",
) )
use_medusa = Path(medusa_config).parent use_medusa = Path(medusa_config).parent