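Model_type location: read `model_type` once, right after the config is loaded in `get_model`, and take the MLP speculator's base model from `base_model_name_or_path` instead of a hard-coded id.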

This commit is contained in:
Nicolas Patry 2024-05-13 14:13:07 +00:00
parent aceb87cc15
commit 6009dadee3

View File

@ -136,6 +136,7 @@ def get_model(
config_dict, _ = PretrainedConfig.get_config_dict( config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
model_type = config_dict.get("model_type", None)
speculator = None speculator = None
if "medusa_num_heads" in config_dict: if "medusa_num_heads" in config_dict:
@ -178,11 +179,10 @@ def get_model(
} }
method = "medusa" method = "medusa"
elif config_dict["model_type"] == "mlp_speculator": elif model_type == "mlp_speculator":
# TODO make this not hardcoded.
mlp_model_id = model_id mlp_model_id = model_id
mlp_revision = revision mlp_revision = revision
model_id = "meta-llama/Meta-Llama-3-8B-Instruct" model_id = config_dict["base_model_name_or_path"]
revision = "main" revision = "main"
speculate_mlp = config_dict["n_predict"] speculate_mlp = config_dict["n_predict"]
if speculate is not None: if speculate is not None:
@ -237,7 +237,6 @@ def get_model(
if speculate > 0: if speculate > 0:
logger.info(f"Using speculation {method} with {speculate} input ids.") logger.info(f"Using speculation {method} with {speculate} input ids.")
model_type = config_dict.get("model_type", None)
if model_type is None: if model_type is None:
# TODO: fix how we determine model type for Mamba # TODO: fix how we determine model type for Mamba
if "ssm_cfg" in config_dict: if "ssm_cfg" in config_dict: