From 6009dadee349e41fdfdde72ead28221d340730e4 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 13 May 2024 14:13:07 +0000 Subject: [PATCH] Model_type location. --- server/text_generation_server/models/__init__.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index fcecf8af..eabc8293 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -136,6 +136,7 @@ def get_model( config_dict, _ = PretrainedConfig.get_config_dict( model_id, revision=revision, trust_remote_code=trust_remote_code ) + model_type = config_dict.get("model_type", None) speculator = None if "medusa_num_heads" in config_dict: @@ -178,11 +179,10 @@ def get_model( } method = "medusa" - elif config_dict["model_type"] == "mlp_speculator": - # TODO make this not hardcoded. + elif model_type == "mlp_speculator": mlp_model_id = model_id mlp_revision = revision - model_id = "meta-llama/Meta-Llama-3-8B-Instruct" + model_id = config_dict["base_model_name_or_path"] revision = "main" speculate_mlp = config_dict["n_predict"] if speculate is not None: @@ -237,7 +237,6 @@ def get_model( if speculate > 0: logger.info(f"Using speculation {method} with {speculate} input ids.") - model_type = config_dict.get("model_type", None) if model_type is None: # TODO: fix how we determine model type for Mamba if "ssm_cfg" in config_dict: