mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Model_type location.
This commit is contained in:
parent
aceb87cc15
commit
6009dadee3
@ -136,6 +136,7 @@ def get_model(
|
|||||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
|
model_type = config_dict.get("model_type", None)
|
||||||
|
|
||||||
speculator = None
|
speculator = None
|
||||||
if "medusa_num_heads" in config_dict:
|
if "medusa_num_heads" in config_dict:
|
||||||
@ -178,11 +179,10 @@ def get_model(
|
|||||||
}
|
}
|
||||||
|
|
||||||
method = "medusa"
|
method = "medusa"
|
||||||
elif config_dict["model_type"] == "mlp_speculator":
|
elif model_type == "mlp_speculator":
|
||||||
# TODO make this not hardcoded.
|
|
||||||
mlp_model_id = model_id
|
mlp_model_id = model_id
|
||||||
mlp_revision = revision
|
mlp_revision = revision
|
||||||
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
|
model_id = config_dict["base_model_name_or_path"]
|
||||||
revision = "main"
|
revision = "main"
|
||||||
speculate_mlp = config_dict["n_predict"]
|
speculate_mlp = config_dict["n_predict"]
|
||||||
if speculate is not None:
|
if speculate is not None:
|
||||||
@ -237,7 +237,6 @@ def get_model(
|
|||||||
if speculate > 0:
|
if speculate > 0:
|
||||||
logger.info(f"Using speculation {method} with {speculate} input ids.")
|
logger.info(f"Using speculation {method} with {speculate} input ids.")
|
||||||
|
|
||||||
model_type = config_dict.get("model_type", None)
|
|
||||||
if model_type is None:
|
if model_type is None:
|
||||||
# TODO: fix how we determine model type for Mamba
|
# TODO: fix how we determine model type for Mamba
|
||||||
if "ssm_cfg" in config_dict:
|
if "ssm_cfg" in config_dict:
|
||||||
|
Loading…
Reference in New Issue
Block a user