diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 4adf1381f..78b68721f 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -1,7 +1,7 @@
 import torch
 
 from loguru import logger
-from transformers import AutoConfig
+from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from typing import Optional
 
@@ -138,10 +138,8 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    config = AutoConfig.from_pretrained(
-        model_id, revision=revision, trust_remote_code=trust_remote_code
-    )
-    model_type = config.model_type
+    config_dict, _ = PretrainedConfig.get_config_dict(model_id, revision=revision, trust_remote_code=trust_remote_code)
+    model_type = config_dict["model_type"]
 
     if model_type == "gpt_bigcode":
         if sharded:
@@ -201,9 +199,9 @@ def get_model(
     if model_type in ["RefinedWeb", "RefinedWebModel"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config.alibi or (
-                    config.model_type == "RefinedWebModel"
-                    and config.n_head_kv != config.n_head
+                if config_dict.get("alibi", False) or (
+                    model_type == "RefinedWebModel"
+                    and config_dict.get("multi_query", True)
                 ):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
@@ -216,7 +214,7 @@ def get_model(
                     FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
                 )
         else:
-            if FLASH_ATTENTION and not config.alibi:
+            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                 return FlashRW(
                     model_id,
                     revision,
@@ -250,7 +248,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
 
-    if config.model_type == "opt":
+    if model_type == "opt":
         if sharded:
             return OPTSharded(
                 model_id,
@@ -294,7 +292,7 @@ def get_model(
         model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
     )
 
-    auto_map = getattr(config, "auto_map", None)
+    auto_map = config_dict.get("auto_map", None)
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
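
For context on the API swap above (not part of the patch itself): `AutoConfig.from_pretrained` resolves and instantiates a config class, which with `trust_remote_code=True` can import and execute Python shipped in the model repository, whereas `PretrainedConfig.get_config_dict` only downloads and JSON-parses `config.json` into a plain dict. That is why the diff switches every attribute access (`config.alibi`) to a dict lookup with an explicit default (`config_dict.get("alibi", False)`). A minimal sketch of the difference; the model id is an arbitrary example:

```python
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig

model_id = "tiiuae/falcon-7b"  # arbitrary example; any Hub repo with a config.json

# Old path: builds a PretrainedConfig subclass. With trust_remote_code=True this
# may download and execute the repo's custom configuration code.
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
print(config.model_type, getattr(config, "alibi", False))

# New path: fetches and parses config.json only; no config class is instantiated,
# so missing keys are handled with dict lookups and defaults instead of attributes.
config_dict, _unused_kwargs = PretrainedConfig.get_config_dict(model_id)
print(config_dict["model_type"], config_dict.get("alibi", False))
```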