From 70ada578b92442d3893451c4920e77bdc372b6d6 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Mon, 20 Jan 2025 18:01:12 +0100
Subject: [PATCH] check for non-native models

---
 server/text_generation_server/models/__init__.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index a7c8c4a7..f7f7a26e 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -385,11 +385,16 @@ def get_model(
         transformers_causal_lm_class = CausalLM
 
         # Fast transformers path
-        transformers_model_class = getattr(
-            transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
-        )
-
-        if FLASH_TRANSFORMERS_BACKEND and transformers_model_class._supports_flex_attn:
+        transformers_model_class = getattr(
+            transformers,
+            modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
+            None,
+        )
+        if (
+            FLASH_TRANSFORMERS_BACKEND
+            and transformers_model_class is not None
+            and transformers_model_class._supports_flex_attn
+        ):
             transformers_causal_lm_class = TransformersFlashCausalLM
 
         quantization_config = config_dict.get("quantization_config", None)
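
A note on why the guard is needed: modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
maps a model_type string (e.g. "llama") to a class *name* string (e.g.
"LlamaForCausalLM"), not to the class itself, and custom architectures loaded
with trust_remote_code never appear in it at all. Indexing with
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type] therefore raised KeyError for
such non-native models; the .get/getattr combination above returns None
instead, and the flash path is only taken when a real class was resolved. The
sketch below reproduces the patched lookup as a standalone script; it is
illustrative rather than TGI code, and the second model_type value is a
made-up placeholder.

import transformers
from transformers.models.auto import modeling_auto


def resolve_native_class(model_type):
    # The mapping yields a class-name string; getattr with a None default
    # absorbs both an unknown model_type and a name that the installed
    # transformers version does not export.
    class_name = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, "")
    return getattr(transformers, class_name, None)


for model_type in ("llama", "my_remote_code_model"):
    cls = resolve_native_class(model_type)
    # `_supports_flex_attn` is the transformers class flag the patch checks;
    # guard the access with getattr in case an older transformers version
    # predates flex-attention support.
    if cls is not None and getattr(cls, "_supports_flex_attn", False):
        print(f"{model_type}: eligible for the flash transformers backend")
    else:
        print(f"{model_type}: fall back to the plain CausalLM path")

Checking "is not None" explicitly (rather than truthiness) keeps the intent
clear: the first condition asks whether transformers knows the architecture,
the second whether that architecture can run on flex attention.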