mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Model_type location.
This commit is contained in:
parent
aceb87cc15
commit
6009dadee3
@ -136,6 +136,7 @@ def get_model(
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
model_type = config_dict.get("model_type", None)
|
||||
|
||||
speculator = None
|
||||
if "medusa_num_heads" in config_dict:
|
||||
@ -178,11 +179,10 @@ def get_model(
|
||||
}
|
||||
|
||||
method = "medusa"
|
||||
elif config_dict["model_type"] == "mlp_speculator":
|
||||
# TODO make this not hardcoded.
|
||||
elif model_type == "mlp_speculator":
|
||||
mlp_model_id = model_id
|
||||
mlp_revision = revision
|
||||
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||
model_id = config_dict["base_model_name_or_path"]
|
||||
revision = "main"
|
||||
speculate_mlp = config_dict["n_predict"]
|
||||
if speculate is not None:
|
||||
@ -237,7 +237,6 @@ def get_model(
|
||||
if speculate > 0:
|
||||
logger.info(f"Using speculation {method} with {speculate} input ids.")
|
||||
|
||||
model_type = config_dict.get("model_type", None)
|
||||
if model_type is None:
|
||||
# TODO: fix how we determine model type for Mamba
|
||||
if "ssm_cfg" in config_dict:
|
||||
|
Loading…
Reference in New Issue
Block a user